{-# LANGUAGE OverloadedLists #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-- | Assorted mostly high rank tensor tests.
module TestHighRankSimplified (testTrees) where

import Prelude

import Data.Int (Int64)
import GHC.Exts (IsList (..))
import GHC.TypeLits (KnownNat, type (+), type (-), type (<=))
import Test.Tasty
import Test.Tasty.HUnit hiding (assert)

import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Mixed.Shape
import Data.Array.Nested.Ranked.Shape
import Data.Array.Nested.Shaped.Shape

import HordeAd
import HordeAd.Core.Ops (tD, tfromPrimal)

import CrossTesting
import EqEpsilon

-- | Every test case of this module, given as (name, assertion) pairs that
-- are turned into 'testCase's. Disabled cases are kept as comments.
testTrees :: [TestTree]
testTrees = map (uncurry testCase)
  [ ("3foo", testFoo)
  , ("3bar", testBar)
--  , ("3barS", testBarS)
  , ("3fooD T Double [1.1, 2.2, 3.3]", testFooD)
  , ("3fooBuild0", testFooBuild0)
  , ("3fooBuildOut", testFooBuildOut)
  , ("3fooBuild91", testFooBuild91)
  , ("3fooBuild92", testFooBuild92)
  , ("3fooBuild21", testFooBuild21)
  , ("3fooBuild25", testFooBuild25)
  , ("3fooBuild21S", testFooBuild21S)
  , ("3fooBuild25S", testFooBuild25S)
  , ("3fooBuildNest21S", testFooBuildNest21S)
  , ("3fooBuildNest25S", testFooBuildNest25S)
  , ("3fooBuild3", testFooBuild3)
  , ("3fooBuildDt", testFooBuildDt)
  , ("3fooBuildDt2", testFooBuildDt2)
  , ("3fooBuild5", testFooBuild5)
  , ("3fooBuild1", testFooBuild1)
  , ("3fooMap", testFooMap)
  , ("3fooMap1", testFooMap1)
  , ("3fooNoGo", testFooNoGo)
  , ("3fooNoGo10", testFooNoGo10)
  , ("3nestedBuildMap1", testNestedBuildMap1)
  , ("3nestedBuildMap10", testNestedBuildMap10)
  , ("3nestedBuildMap11", testNestedBuildMap11)
--  , ("3nestedBuildMap7", testNestedBuildMap7)
  , ("3nestedSumBuild1", testNestedSumBuild1)
--  , ("3nestedSumBuild5", testNestedSumBuild5)
  , ("3nestedSumBuildB", testNestedSumBuildB)
  , ("3nestedBuildIndex", testNestedBuildIndex)
  , ("3barReluADValDt", testBarReluADValDt)
  , ("3barReluADValDt2", testBarReluADValDt2)
  , ("3barReluADVal", testBarReluADVal)
  , ("3barReluADVal3", testBarReluADVal3)
  , ("3braidedBuilds", testBraidedBuilds)
  , ("3braidedBuilds1", testBraidedBuilds1)
  , ("3recycled", testRecycled)
-- takes too long (can't be helped)  , ("3recycled1", testRecycled1)
  , ("3concatBuild0", testConcatBuild0)
  , ("3concatBuild1", testConcatBuild1)
  , ("3concatBuild0m", testConcatBuild0m)
  , ("3concatBuild1m", testConcatBuild1m)
  , ("3concatBuild2", testConcatBuild2)
  , ("3concatBuild22", testConcatBuild22)
  , ("3concatBuild3", testConcatBuild3)
  , ("3logistic0", testLogistic0)
  , ("3logistic5", testLogistic5)
  , ("3logistic52", testLogistic52)
  , ("3logistic0Old", testLogistic0Old)
  , ("3logistic5Old", testLogistic5Old)
  , ("3logistic52Old", testLogistic52Old)
  , ("3logisticA0", testLogisticA0)
  , ("3logisticB0", testLogisticB0)
  , ("3logisticC0", testLogisticC0)
  ]

-- | Three-argument test function: @atan2H z w + z * w@ with @w = x * sin y@.
foo :: RealFloatH a => (a,a,a) -> a
foo (x, y, z) = atan2H z w + z * w
  where
    w = x * sin y

-- | Same formula as 'foo', kept as a separate binding; used by '_barF'.
_fooF :: RealFloatH a => (a,a,a) -> a
_fooF (x, y, z) = atan2H z w + z * w
  where
    w = x * sin y

-- | Gradient of @rsum0 . foo@ at @(t16, t16, t16)@ in 'Float'; the expected
-- value holds one rank-5 gradient per component of the input triple.
testFoo :: Assertion
testFoo =
  assertEqualUpToEpsilon 1e-3
    (ringestData [2,2,1, 2,2] [-4.6947093,1.5697206,-1.6332961,0.34882763,1.5697206,-1.0,-0.9784988,-0.9158946,6.6326222,3.6699238,7.85237,-2.9069107,17.976654,0.3914159,32.98194,19.807974], ringestData [2,2,1, 2,2] [6.943779,-1.436789,33.67549,0.22397964,-1.436789,-1.0,-0.975235,-0.90365005,147.06645,-73.022705,-9.238474,-10.042692,-980.2843,-7.900571,-14.451739,436.9084], ringestData [2,2,1, 2,2] [-4.8945336,2.067469,-1.7196897,1.3341143,2.067469,1.0,0.99846554,0.99536234,6.6943173,3.7482092,7.977362,-3.1475093,18.000969,0.48736274,33.01224,19.845064])
    (grad (kfromR . rsum0 @5 @(TKScalar Float) . foo) (t16, t16, t16))

-- | Two-argument test function built on 'foo':
-- @atan2H x w + y * w@ with @w = foo (x, y, x) * sin y@.
bar :: forall a. RealFloatH a => (a, a) -> a
bar (x, y) = atan2H x w + y * w
  where
    w = foo (x, y, x) * sin y

-- | Same formula as 'bar', but built on '_fooF' instead of 'foo'.
_barF :: forall a. RealFloatH a => (a, a) -> a
_barF (x, y) = atan2H x w + y * w
  where
    w = _fooF (x, y, x) * sin y

-- | Gradient of @rsum0 . bar@ at @(t48, t48)@, computed by 'cgrad' with 'bar'
-- instantiated at rank-7 'Float' ADVal tensors; one expected gradient per
-- component of the input pair.
testBar :: Assertion
testBar =
  assertEqualUpToEpsilon 1e-5
    (Concrete $ Nested.rfromListLinear [3,1,2,2,1,2,2] [304.13867,914.9335,823.0187,1464.4688,5264.3306,1790.0055,1535.4309,3541.6572,304.13867,914.9335,823.0187,1464.4688,6632.4355,6047.113,1535.4309,1346.6815,45.92141,6.4903135,5.5406737,1.4242969,6.4903135,1.1458766,4.6446533,2.3550234,88.783676,27.467598,125.27507,18.177452,647.1915,0.3878851,2177.6152,786.1792,6.4903135,6.4903135,6.4903135,6.4903135,2.3550234,2.3550234,2.3550234,2.3550234,21.783596,2.3550234,2.3550234,2.3550234,21.783596,21.783596,21.783596,21.783596],Concrete $ Nested.rfromListLinear [3,1,2,2,1,2,2] [-5728.7617,24965.113,32825.07,-63505.953,-42592.203,145994.88,-500082.5,-202480.06,-5728.7617,24965.113,32825.07,-63505.953,49494.473,-2446.7632,-500082.5,-125885.58,-43.092484,-1.9601002,-98.97709,2.1931143,-1.9601002,1.8243169,-4.0434446,-1.5266153,2020.9731,-538.0603,-84.28137,62.963814,-34986.996,-9.917454,135.30023,17741.998,-1.9601002,-1.9601002,-1.9601002,-1.9601002,-1.5266153,-1.5266153,-1.5266153,-1.5266153,-4029.1775,-1.5266153,-1.5266153,-1.5266153,-4029.1775,-4029.1775,-4029.1775,-4029.1775])
    (cgrad (kfromR . rsum0 . bar @(ADVal Concrete (TKR 7 Float))) (t48, t48))

{- TODO: divergent result; bring back when GHC 9.10 dropped:
testBarS :: Assertion
testBarS =
  assertEqualUpToEpsilon 1e-5
    (sconcrete $ Nested.sfromListPrimLinear @_ @'[3, 1, 2, 2, 1, 2, 2] knownShS [304.13867,914.9335,823.0187,1464.4688,5264.3306,1790.0055,1535.4309,3541.6572,304.13867,914.9335,823.0187,1464.4688,6632.4355,6047.113,1535.4309,1346.6815,45.92141,6.4903135,5.5406737,1.4242969,6.4903135,1.1458766,4.6446533,2.3550234,88.783676,27.467598,125.27507,18.177452,647.1915,0.3878851,2177.6152,786.1792,6.4903135,6.4903135,6.4903135,6.4903135,2.3550234,2.3550234,2.3550234,2.3550234,21.783596,2.3550234,2.3550234,2.3550234,21.783596,21.783596,21.783596,21.783596], sconcrete $ Nested.sfromListPrimLinear @_ @'[3, 1, 2, 2, 1, 2, 2] knownShS [-5728.761,24965.113,32825.074,-63505.957,-42592.203,145994.89,-500082.5,-202480.05,-5728.761,24965.113,32825.074,-63505.957,49494.473,-2446.7632,-500082.5,-125885.58,-43.092484,-1.9601007,-98.97708,2.1931143,-1.9601007,1.8243167,-4.0434446,-1.5266151,2020.9731,-538.06036,-84.28139,62.963818,-34986.992,-9.917454,135.3003,17741.996,-1.9601007,-1.9601007,-1.9601007,-1.9601007,-1.5266151,-1.5266151,-1.5266151,-1.5266151,-4029.1775,-1.5266151,-1.5266151,-1.5266151,-4029.1775,-4029.1775,-4029.1775,-4029.1775])
    (cgrad (kfromS . ssum0 . _barF @(ADVal Concrete (TKS '[3, 1, 2, 2, 1, 2, 2] Float))) (sfromR t48, sfromR t48))
-}

-- | The formula of 'foo' (a function from @R^3@ to @R@), with its three
-- arguments taken as a length-3 'ListR' of dual-number tensors.
fooD :: forall r n. (RealFloatH (ADVal Concrete (TKR n r)))
     => ListR 3 (ADVal Concrete (TKR n r)) -> ADVal Concrete (TKR n r)
fooD (x ::: y ::: z ::: ZR) = atan2H z w + z * w
  where
    w = x * sin y

-- | Gradient of @rsum0 . fooD@ at the three-element input list
-- @[t128, all-0.7 tensor, t128]@; the expected value holds one rank-10
-- gradient (128 entries) per list element.
--
-- The data is wrapped inside the brackets, so every continuation line is
-- indented. A numeric literal must never be split across lines, and no
-- continuation line may start in column 0: either would break the top-level
-- layout.
testFooD :: Assertion
testFooD =
  assertEqualUpToEpsilon 1e-10
    (fromList
       [ ringestData [1,2,2,1,2,2,2,2,2,1]
           [ 18.73108960474591,20.665204824764675,25.821775835995922,18.666613887422585,34.775664100213014,62.54884873632415,37.93303229694526,11.635186977032971
           , 18.73108960474591,20.665204824764675,25.821775835995922,20.600738734367262,34.775664100213014,62.54884873632415,16.663997008808924,3.1300339898598155
           , 1.060799258653783,3.78942741815228,0.1889454555944933,-1.060799258653783,62.54884873632415,37.93303229694526,62.54884873632415,35.99996432769119
           , 62.54884873632415,37.93303229694526,11.635186977032971,18.73108960474591,20.665204824764675,25.821775835995922,20.665204824764675,20.665204824764675
           , 25.821775835995922,34.134947381491145,34.775664100213014,45527.22315787758,-4.488300547708207,2.1475176207684497,8.404498097344806,5.747373381623309
           , 5.096832468946128,-2.4630526910399646,18.666613887422585,1.7769486222994448,-215.8115662030395,16.73214939773215,1.060799258653783,1.060799258653783
           , 1.060799258653783,1.060799258653783,2.1475176207684497,2.1475176207684497,2.1475176207684497,2.1475176207684497,16.08742477551077,16.08742477551077
           , 16.08742477551077,16.08742477551077,2.1475176207684497,2.1475176207684497,2.1475176207684497,2.1475176207684497,16.08742477551077,16.08742477551077
           , 16.08742477551077,16.08742477551077,25.821775835995922,5.096832468946128,7.045006174919766,-1.7808956511653404,16.663997008744435,18.533999054066836
           , -25.177267779903083,16.60317012020362,25.821775835995922,5.096832468946128,7.045006174919766,-1.7808956511653404,16.663997008744435,18.533999054066836
           , -12.280721583745471,16.60317012020362,5.161956818274285,18.73108960474591,20.665204824764675,25.821775835995922,20.665204824764675,25.821775835995922
           , 188.11000552192755,34.775664100213014,62.54884873632415,35.99996432769119,62.54884873632415,55.32933980086011,62.54884873632415,55.32933980086011
           , 11.635186977032971,18.73108960474591,20.665204824764675,25.821775835995922,20.665204824764675,25.821775835995922,20.665204824764675,25.821775835995922
           , 14.152094926881784,34.775664100213014,62.54884873632415,53.39649491503442,62.54884873632415,14.72904006548922,62.54884873632415,37.93303229694526
           , 11.635186977032971,18.73108960474591,20.665204824764675,25.821775835995922,20.665204824764675,25.821775835995922,20.665204824764675,25.821775835995922
           , 57.33025874582143,34.775664100213014,62.54884873632415,36.64432517917614,62.54884873632415,34.06684929392724,62.54884873632415,35.99996432769119 ]
       , ringestData [1,2,2,1,2,2,2,2,2,1]
           [ 647.1354943759653,787.5605199613974,1229.333367336918,642.6917612678424,2229.2701397674327,7210.705208776531,2652.3459120285806,250.02943073785886
           , 647.1354943759653,787.5605199613974,1229.333367336918,782.6578815409038,2229.2701397674327,7210.705208776531,512.2982591657892,18.580536443699742
           , 2.518850510725482,26.993800503829114,0.2243239488720164,2.518850510725482,7210.705208776531,2652.3459120285806,7210.705208776531,2388.9603285490866
           , 7210.705208776531,2652.3459120285806,250.02943073785886,647.1354943759653,787.5605199613974,1229.333367336918,787.5605199613974,787.5605199613974
           , 1229.333367336918,2147.9011858437157,2229.2701397674327,-0.5405182383359878,-0.5328698165396271,-0.5099245509210925,130.7140495214786,61.4116989316311
           , 48.40938174779479,11.696956758139343,642.6917612678424,6.317020301049852,85833.87394976329,516.4928003659018,2.518850510725482,2.518850510725482
           , 2.518850510725482,2.518850510725482,-0.5099245509210925,-0.5099245509210925,-0.5099245509210925,-0.5099245509210925,477.4973215160379,477.4973215160379
           , 477.4973215160379,477.4973215160379,-0.5099245509210925,-0.5099245509210925,-0.5099245509210925,-0.5099245509210925,477.4973215160379,477.4973215160379
           , 477.4973215160379,477.4973215160379,1229.333367336918,48.40938174779479,92.00538642301063,6.3430614471479245,512.2982591618282,633.5999783697488
           , 1168.7578661039847,508.56903530563443,1229.333367336918,48.40938174779479,92.00538642301063,6.3430614471479245,512.2982591618282,633.5999783697488
           , 278.48156010484087,508.56903530563443,49.64077766932281,647.1354943759653,787.5605199613974,1229.333367336918,787.5605199613974,1229.333367336918
           , 65212.963738386214,2229.2701397674327,7210.705208776531,2388.9603285490866,7210.705208776531,5642.338335044463,7210.705208776531,5642.338335044463
           , 250.02943073785886,647.1354943759653,787.5605199613974,1229.333367336918,787.5605199613974,1229.333367336918,787.5605199613974,1229.333367336918
           , 369.6431004072799,2229.2701397674327,7210.705208776531,5255.048317224881,7210.705208776531,400.3514287686239,7210.705208776531,2652.3459120285806
           , 250.02943073785886,647.1354943759653,787.5605199613974,1229.333367336918,787.5605199613974,1229.333367336918,787.5605199613974,1229.333367336918
           , 6057.774447242021,2229.2701397674327,7210.705208776531,2475.225838667682,7210.705208776531,2139.3419044407133,7210.705208776531,2388.9603285490866 ]
       , ringestData [1,2,2,1,2,2,2,2,2,1]
           [ 18.76237979248771,20.69357069589509,25.8444826804669,18.698011972363496,34.7925278085306,62.558226125235436,37.948492946856575,11.685493300971446
           , 18.76237979248771,20.69357069589509,25.8444826804669,20.629193248844963,34.7925278085306,62.558226125235436,16.699160877305292,3.3121428825170947
           , 1.516071490296981,3.9411848287000124,1.0994899188808887,-1.516071490296981,62.558226125235436,37.948492946856575,62.558226125235436,36.01625479268449
           , 62.558226125235436,37.948492946856575,11.685493300971446,18.76237979248771,20.69357069589509,25.8444826804669,20.69357069589509,20.69357069589509
           , 25.8444826804669,34.1521274657041,34.7925278085306,-45527.22317076194,4.617144085155745,-2.4052046956635262,8.474005308282699,5.84854498865513
           , 5.210650526856928,-2.6906888068615635,18.698011972363496,2.0810391881996813,-215.8142842462135,16.767170338627782,1.516071490296981,1.516071490296981
           , 1.516071490296981,1.516071490296981,-2.4052046956635262,-2.4052046956635262,-2.4052046956635262,-2.4052046956635262,16.123846116986126,16.123846116986126
           , 16.123846116986126,16.123846116986126,-2.4052046956635262,-2.4052046956635262,-2.4052046956635262,-2.4052046956635262,16.123846116986126,16.123846116986126
           , 16.123846116986126,16.123846116986126,25.8444826804669,5.210650526856928,7.127782944309438,-2.0844104722608057,16.69916087724094,18.565621417897145
           , -25.200555362084323,16.638462541261234,25.8444826804669,5.210650526856928,7.127782944309438,-2.0844104722608057,16.69916087724094,18.565621417897145
           , -12.328394068734287,16.638462541261234,5.2743697149763085,18.76237979248771,20.69357069589509,25.8444826804669,20.69357069589509,25.8444826804669
           , 188.113123824884,34.7925278085306,62.558226125235436,36.01625479268449,62.558226125235436,55.33994055377702,62.558226125235436,55.33994055377702
           , 11.685493300971446,18.76237979248771,20.69357069589509,25.8444826804669,20.69357069589509,25.8444826804669,20.69357069589509,25.8444826804669
           , 14.193483311576621,34.7925278085306,62.558226125235436,53.40747931617656,62.558226125235436,14.768811697198851,62.558226125235436,37.948492946856575
           , 11.685493300971446,18.76237979248771,20.69357069589509,25.8444826804669,20.69357069589509,25.8444826804669,20.69357069589509,25.8444826804669
           , 57.34048958248757,34.7925278085306,62.558226125235436,36.660329315674915,62.558226125235436,34.08406370302229,62.558226125235436,36.01625479268449 ]
       ])
    (cgrad (kfromR . rsum0 . fooD) (fromList [ t128
               , rreplicate0N [1, 2, 2, 1, 2, 2, 2, 2, 2, 1] (rscalar (0.7 :: Double))
               , t128 ]))

-- | A length-2 build whose two rows are both @rsum v@, i.e. @v@ summed over
-- its first dimension. The sum is bound once, outside the build function.
fooBuild0 :: forall target r n. (ADReady target, GoodScalar r, KnownNat n)
          => target (TKR (1 + n) r) -> target (TKR (1 + n) r)
fooBuild0 v = rbuild1 2 (\_ -> summed)
  where
    summed = rsum v

-- | Every one of the 16 entries of @t16@ gets gradient 2.0 from 'fooBuild0'.
testFooBuild0 :: Assertion
testFooBuild0 =
  assertEqualUpToEpsilon' 1e-10 expected (rev' @Double @5 fooBuild0 t16)
  where
    expected = ringestData [2,2,1,2,2] (replicate 16 2.0)

-- | A length-2 build: row 0 is @v[ix + 1]@ (that is, @v[1]@) and row 1 is
-- @rsum v@. The selection is made with 'ifH'.
fooBuildOut
  :: forall target r n. (ADReady target, GoodScalar r, KnownNat n)
  => target (TKR (1 + n) r) -> target (TKR (1 + n) r)
fooBuildOut v =
  rbuild1 2 $ \ix ->
    let shifted = rindex v [ix + 1]  -- may index out of bounds; only selected when ix == 0
        total = rsum v
    in ifH (ix ==. 0) shifted total

-- | The first half of @t16@ gets gradient 1.0 and the second half gets 2.0.
testFooBuildOut :: Assertion
testFooBuildOut =
  assertEqualUpToEpsilon' 1e-10 expected (rev' @Double @5 fooBuildOut t16)
  where
    expected = ringestData [2,2,1,2,2] (replicate 8 1.0 ++ replicate 8 2.0)

-- | A length-2 build whose row @ix@ is chosen by data-dependent index
-- arithmetic on @v@:
--
-- * let @k = ix - floor (S) - 10001@, where S is the sum of 23760
--   (= 5*12*11*9*4) replicated copies of @rsum0 v@;
-- * if @0 <= k <= 1@ the row is @v[k]@;
-- * otherwise the row is @sqrt (abs v[m])@, where
--   @m = (ix - floor (rsum0 v) - 10001) `remH` 2@, shifted by +2 when negative.
--
-- For the @[3.0, 2.0]@ input used by testFooBuild21, @rsum0 v = 5@, so @k@ is
-- far below 0 and the else branch is taken.
fooBuild2
  :: forall target r n.
     (ADReady target, GoodScalar r, KnownNat n, Floating (target (TKR n r)), RealFloat r)
  => target (TKR (1 + n) r) -> target (TKR (1 + n) r)
fooBuild2 v =
  -- ix: the build index, converted to a primal shaped Int64 scalar.
  rbuild1 2 $ \ix' -> let ix :: PrimalOf target (TKS '[] Int64)
                          ix = sfromR $ rfromK ix' in
    -- Guard: 0 <= k <= 1 (k as in the header comment). The first occurrence
    -- omits the explicit @target application that the other two use.
    ifH (ix - (sprimalPart . sfloor . sfromR) (rsum0  @5 @(TKScalar r)
                      $ rreplicate0N [5,12,11,9,4] (rsum0 v)) - sscalar 10001 >=. sscalar 0
         &&* ix - (sprimalPart . sfloor . sfromR) (rsum0 @5 @(TKScalar r) @target
                          $ rreplicate0N [5,12,11,9,4] (rsum0 v)) - sscalar 10001 <=. sscalar 1)
        -- Then-branch: row v[k].
        (rindex v [kfromR $ rfromS $ ix - (sprimalPart . sfloor . sfromR) (rsum0  @5 @(TKScalar r) @target
                                $ rreplicate0N [5,12,11,9,4] (rsum0 v)) - sscalar 10001])
           -- index out of bounds; also fine
        -- Else-branch: sqrt (abs v[m]). Unlike the guard, the floor here is of
        -- rsum0 v itself, not of the replicated sum. rr is moved into
        -- [0, 2) by adding 2 when signum rr equals negate (signum 2), i.e.
        -- when rr is negative (assuming remH has `rem` semantics).
        (sqrt $ abs $ rindex v [kfromS
                                $ let rr = (ix - (sfromR . rprimalPart . rfloor) (rsum0 v) - sscalar 10001) `remH` sscalar 2
                                  in ifH (signum rr ==. negate (signum $ sscalar 2))
                                     (rr + sscalar 2)
                                     rr])

-- | Applies 'fooBuild2' to each of the @k@ tensors in the list and sums the
-- results with a right fold. Like any 'foldr1', it fails when @k = 0@.
fooBuild2L
  :: forall k target r n.
     (ADReady target, GoodScalar r, KnownNat n, Floating (target (TKR n r)), RealFloat r)
  => ListR k (target (TKR (1 + n) r)) -> target (TKR (1 + n) r)
fooBuild2L xs = foldr1 (+) (fooBuild2 <$> xs)

-- | 'fooBuild2L' over 50 length-2 vectors, differentiated by 'cgrad'
-- (ADVal Concrete).
--
-- The input and the expected gradient are both a 10-element block repeated
-- 5 times. Each expected entry equals @1 / (2 * sqrt x)@ for the
-- corresponding input entry @x@.
testFooBuild91 :: Assertion
testFooBuild91 =
  assertEqualUpToEpsilon 1e-8
    (fromList $ map (ringestData [2]) $ concat $ replicate 5 expectedBlock)
    (cgrad (kfromR . rsum0 @1 . fooBuild2L @50 @(ADVal Concrete) @Double @0)
       (fromList $ map (ringestData [2]) $ concat $ replicate 5 inputBlock))
  where
    inputBlock =
      [ [0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.4, 0.5], [0.5, 0.6]
      , [0.6, 0.7], [0.7, 0.8], [0.8, 0.9], [0.9, 0.1], [0.1, 0.9] ]
    expectedBlock =
      [ [1.5811388300841895,1.118033988749895]
      , [1.118033988749895,0.9128709291752769]
      , [0.9128709291752769,0.7905694150420948]
      , [0.7905694150420948,0.7071067811865475]
      , [0.7071067811865475,0.6454972243679028]
      , [0.6454972243679028,0.5976143046671968]
      , [0.5976143046671968,0.5590169943749475]
      , [0.5590169943749475,0.5270462766947299]
      , [0.5270462766947299,1.5811388300841895]
      , [1.5811388300841895,0.5270462766947299] ]

-- | Same data as testFooBuild91, but differentiated symbolically with 'grad'
-- (AstTensor AstMethodLet FullSpan).
--
-- The input and the expected gradient are both a 10-element block repeated
-- 5 times. Each expected entry equals @1 / (2 * sqrt x)@ for the
-- corresponding input entry @x@.
testFooBuild92 :: Assertion
testFooBuild92 =
  assertEqualUpToEpsilon 1e-8
    (fromList $ map (ringestData [2]) $ concat $ replicate 5 expectedBlock)
    (grad
       (kfromR . rsum0 @1 . fooBuild2L @50 @(AstTensor AstMethodLet FullSpan) @Double @0)
       (fromList $ map (ringestData [2]) $ concat $ replicate 5 inputBlock))
  where
    inputBlock =
      [ [0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.4, 0.5], [0.5, 0.6]
      , [0.6, 0.7], [0.7, 0.8], [0.8, 0.9], [0.9, 0.1], [0.1, 0.9] ]
    expectedBlock =
      [ [1.5811388300841895,1.118033988749895]
      , [1.118033988749895,0.9128709291752769]
      , [0.9128709291752769,0.7905694150420948]
      , [0.7905694150420948,0.7071067811865475]
      , [0.7071067811865475,0.6454972243679028]
      , [0.6454972243679028,0.5976143046671968]
      , [0.5976143046671968,0.5590169943749475]
      , [0.5590169943749475,0.5270462766947299]
      , [0.5270462766947299,1.5811388300841895]
      , [1.5811388300841895,0.5270462766947299] ]

-- | 'fooBuild2' on the vector @[3, 2]@. The expected entries equal
-- @1 / (2 * sqrt 3)@ and @1 / (2 * sqrt 2)@.
testFooBuild21 :: Assertion
testFooBuild21 =
  assertEqualUpToEpsilon' 1e-10 expected (rev' @Double @1 fooBuild2 input)
  where
    input = ringestData [2] [3.0,2.0]
    expected = ringestData [2] [0.2886751345948129,0.35355339059327373]

-- | 'fooBuild2' on the rank-5 fixture @t16@.
testFooBuild25 :: Assertion
testFooBuild25 =
  assertEqualUpToEpsilon' 1e-10 expected (rev' @Double @5 fooBuild2 t16)
  where
    expected =
      ringestData [2,2,1,2,2]
        [0.22360679774997896,0.35355339059327373,0.20412414523193154,0.5,-0.35355339059327373,500.0,1.5811388300841895,-1.118033988749895,0.1381447409988844,0.16666666666666666,0.17677669529663687,-0.25,8.574929257125441e-2,0.288948802391873,-8.703882797784893e-2,9.805806756909202e-2]

-- | Shaped counterpart of 'fooBuild2': the same data-dependent row selection,
-- written with shaped operations (ssum0, sreplicate0N, sindex) and converted
-- to a ranked result with 'rfromS'. The build always has length 2
-- (@sbuild1 \@2@). The type parameter @k@ (the outer size of @v@) only
-- appears in the signature.
fooBuild2S
  :: forall k sh target r.
     (ADReady target, GoodScalar r, KnownNat k, Floating (target (TKS sh r)), RealFloat r, KnownShS sh)
  => target (TKS (k : sh) r) -> target (TKR (1 + Rank sh) r)
fooBuild2S v = rfromS $
  -- ix: the build index, converted to a primal shaped Int64 scalar.
  sbuild1 @2 $ \ix' -> let ix :: PrimalOf target (TKS '[] Int64)
                           ix = sfromR $ rfromK ix' in
    -- Guard: 0 <= k <= 1, where k = ix - floor (sum of the replicated
    -- ssum0 v) - 10001.
    ifH (ix - (sprimalPart . sfloor) (ssum0 @[5,12,11,9,4] @(TKScalar r)
             $ sreplicate0N @[5,12,11,9,4] (ssum0 v)) - srepl 10001 >=. srepl 0
         &&* ix - (sprimalPart . sfloor) (ssum0 @[5,12,11,9,4] @(TKScalar r)
             $ sreplicate0N @[5,12,11,9,4] (ssum0 v)) - srepl 10001 <=. srepl 1)
        -- Then-branch: row v[k].
        (sindex v ((kfromS $ ix - (sprimalPart . sfloor) (ssum0 @[5,12,11,9,4] @(TKScalar r)
             $ sreplicate0N @[5,12,11,9,4] (ssum0 v)) - srepl 10001) :.$ ZIS ))
           -- index out of bounds; also fine
        -- Else-branch: sqrt (abs v[m]) with m = (ix - floor (ssum0 v) - 10001)
        -- `remH` 2, plus 2 when that remainder is negative.
        (sqrt $ abs $ sindex v ((kfromR $ rfromS $ let rr = (ix - (sprimalPart . sfloor) (ssum0 v) - srepl 10001) `remH` srepl 2
                                in ifH (signum rr ==. negate (signum $ srepl 2))
                                   (rr + srepl 2)
                                   rr) :.$ ZIS))

-- | 'fooBuild2S' on the vector @[3, 2]@; same expectation as testFooBuild21.
testFooBuild21S :: Assertion
testFooBuild21S =
  assertEqualUpToEpsilon' 1e-10 expected
    (rev' @Double @1 (fooBuild2S @2 @'[] . sfromR) input)
  where
    input = ringestData [2] [3.0,2.0]
    expected = ringestData [2] [0.2886751345948129,0.35355339059327373]

-- | 'fooBuild2S' on @t16@; same expectation as testFooBuild25.
testFooBuild25S :: Assertion
testFooBuild25S =
  assertEqualUpToEpsilon' 1e-10 expected
    (rev' @Double @5 (fooBuild2S @2 @[2, 1, 2, 2] . sfromR) t16)
  where
    expected =
      ringestData [2,2,1,2,2]
        [0.22360679774997896,0.35355339059327373,0.20412414523193154,0.5,-0.35355339059327373,500.0,1.5811388300841895,-1.118033988749895,0.1381447409988844,0.16666666666666666,0.17677669529663687,-0.25,8.574929257125441e-2,0.288948802391873,-8.703882797784893e-2,9.805806756909202e-2]

-- | Variant of the shaped 'fooBuild2' logic that routes parts of the
-- computation through nested tensors (snest/sunNest), a product
-- (tpair/tproject2/tfromPrimal) and an explicit 'tlet'.
-- NOTE(review): presumably meant to exercise those operations under AD;
-- confirm the intent.
fooBuildNest2S
  :: forall k sh target r.
     (ADReady target, GoodScalar r, KnownNat k, Floating (target (TKS sh r)), RealFloat r, KnownShS sh)
  => target (TKS (k : sh) r) -> target (TKR (1 + Rank sh) r)
fooBuildNest2S v = rfromS $
  -- ix: the build index, converted to a primal shaped Int64 scalar.
  sbuild1 @2 $ \ix' -> let ix :: PrimalOf target (TKS '[] Int64)
                           ix = sfromR $ rfromK ix' in
    -- Guard 0 <= k <= 1; the first floored sum is wrapped in snest/sunNest.
    ifH (ix - (sunNest @_ @'[] @'[] . sprimalPart . snest knownShS . sfloor) (ssum0 @[5,12,11,9,4] @(TKScalar r)
             $ sreplicate0N @[5,12,11,9,4] (ssum0 v)) - srepl 10001 >=. srepl 0
         &&* ix - (sprimalPart . sfloor) (ssum0 @[5,12,11,9,4] @(TKScalar r)
             $ sreplicate0N @[5,12,11,9,4] (ssum0 v)) - srepl 10001 <=. srepl 1)
-- TODO:        (sindex v (ShapedList.singletonIndex (ix - (sprimalPart . sfloor) (ssum0 @[5,12,11,9,4] @r  $ sunNest $ treplicate (SNat @5) knownSTK $ snest (knownShS @[12,11])
        -- Then-branch: v[k], with the replicated sum passed through
        -- snest, tpair/tproject2 and tfromPrimal before being summed and floored.
        (sindex v ((kfromR $ rfromS $ ix - (sprimalPart . sfloor) (ssum0 @[5,12,11,9,4] @(TKScalar r) @target $ sunNest $ tproject2 $ tfromPrimal knownSTK $ tpair tunit (sprimalPart $ snest (knownShS @[5,12,11])
             $ sreplicate0N @[5,12,11,9,4] (ssum0 v))) - srepl 10001) :.$ ZIS))
           -- index out of bounds; also fine
-- TODO:        (sunNest @_ @'[] @sh $ tlet (snest (knownShS @'[]) $ (sfromPrimal ix - sfloor (ssum0 v) - srepl 10001) `remH` srepl 2) $ \rr -> snest (knownShS @'[]) $ sqrt $ abs $ sindex v (ShapedList.singletonIndex (ifH (signum (sprimalPart (sunNest rr)) ==. negate (signum $ srepl 2)) (sprimalPart (sunNest rr) + srepl 2) (sprimalPart (sunNest rr)))))
        -- Else-branch: rr (the remainder mod 2) is shared via tlet; the
        -- selected sqrt (abs v[m]) is wrapped by snest and unwrapped by sunNest.
        (sunNest @_ @'[] @sh $ tlet ((sfromPrimal ix - sfloor (ssum0 v) - srepl 10001) `remH` srepl 2) $ \rr -> snest (knownShS @'[]) $ sqrt $ abs $ sindex v ((kfromS $ ifH (signum (sprimalPart rr) ==. negate (signum $ srepl 2)) (sprimalPart rr + srepl 2) (sprimalPart rr)) :.$ ZIS))

-- | 'fooBuildNest2S' on the vector @[3, 2]@; same expectation as testFooBuild21.
testFooBuildNest21S :: Assertion
testFooBuildNest21S =
  assertEqualUpToEpsilon' 1e-10 expected
    (rev' @Double @1 (fooBuildNest2S @2 @'[] . sfromR) input)
  where
    input = ringestData [2] [3.0,2.0]
    expected = ringestData [2] [0.2886751345948129,0.35355339059327373]

-- | 'fooBuildNest2S' on @t16@; same expectation as testFooBuild25.
testFooBuildNest25S :: Assertion
testFooBuildNest25S =
  assertEqualUpToEpsilon' 1e-10 expected
    (rev' @Double @5 (fooBuildNest2S @2 @[2, 1, 2, 2] . sfromR) t16)
  where
    expected =
      ringestData [2,2,1,2,2]
        [0.22360679774997896,0.35355339059327373,0.20412414523193154,0.5,-0.35355339059327373,500.0,1.5811388300841895,-1.118033988749895,0.1381447409988844,0.16666666666666666,0.17677669529663687,-0.25,8.574929257125441e-2,0.288948802391873,-8.703882797784893e-2,9.805806756909202e-2]

-- | A length-22 build whose every row is @bar (ones, v[min 1 (ix + 1)])@.
-- Since @ix >= 0@, the index is always 1, so only row 1 of @v@ is read.
fooBuild3 :: forall target r n.
             ( ADReady target, GoodScalar r, KnownNat n, RealFloatH (target (TKR n r)) )
          => target (TKR (1 + n) r) -> target (TKR (1 + n) r)
fooBuild3 v = rbuild1 22 row
  where
    shTail = shrTail (rshape v)
    row ix = bar ( rreplicate0N shTail (rscalar 1)
                 , rindex v [minH 1 (ix + 1)] )  -- minH caps the index at 1

-- | 'fooBuild3' only reads row 1 of @t16@, so the first 8 entries (row 0)
-- get zero gradient.
testFooBuild3 :: Assertion
testFooBuild3 =
  assertEqualUpToEpsilon' 1e-10 expected (rev' @Double @5 fooBuild3 t16)
  where
    expected =
      ringestData [2,2,1,2,2]
        (replicate 8 0.0
         ++ [423.72976235076516,-260.41676627885636,-17.60047532855961,151.18955028869385,-1059.9668424433578,-65.00898015327623,-21.49245448729951,743.7622427949768])

-- | A length-2 build combining 'foo' and 'bar'. Row ix is
-- r * foo (3, 5 * r, r * vMin) + bar (r, v[min 1 (ix + 1)]), where r is
-- @rsum v@ and vMin is the minimum entry of @v@, replicated to r's shape.
fooBuild5 :: forall target r n.
             ( ADReady target, GoodScalar r, KnownNat n, RealFloatH (target (TKR n r)) )
          => target (TKR (1 + n) r) -> target (TKR (1 + n) r)
fooBuild5 v =
  rbuild1 2 $ \ix ->
    r * foo ( rreplicate0N shTail (rscalar 3)
            , rrepl (rshape r) 5 * r
            , r * vMin )
    + bar (r, rindex v [minH 1 (ix + 1)])  -- minH caps the index at 1
  where
    r = rsum v
    shTail = shrTail (rshape v)
    vMin = rreplicate0N shTail (rminimum (rflatten v))

-- | Vector-Jacobian product of 'fooBuild5' at @t16@ with a cotangent that is
-- 42 everywhere.
testFooBuildDt :: Assertion
testFooBuildDt =
  assertEqualUpToEpsilon 1e-5 expected
    (vjp @_ @(TKR 5 Double)
           fooBuild5 t16 (rreplicate0N [2, 2, 1, 2, 2] (rscalar 42)))
  where
    expected =
      rconcrete $ Nested.rfromListPrimLinear [2,2,1,2,2]
        [1.1033568028244503e7,74274.22833989389,-5323238.2765011545,253074.03394016018,4.14744804041263e7,242643.98750578283,-1.922371592087736e7,2.730274503834733e7,1.135709425204681e7,6924.195066252549,-5345004.080027547,255679.51406100337,3.8870981856703006e7,241810.92121468345,-1.9380955730171032e7,2.877024321777493e7]

-- | Like testFooBuildDt, but the function returns the pair (y, y) with
-- y = fooBuild5 x, and both cotangent components are 42 everywhere. By
-- linearity, the expected value is (numerically) twice testFooBuildDt's.
testFooBuildDt2 :: Assertion
testFooBuildDt2 =
  assertEqualUpToEpsilon 1e-5 expected
    (vjp @_ @(TKProduct (TKR 5 Double) (TKR 5 Double))
           (\x -> let y = fooBuild5 x in tpair y y) t16 (let dt = rreplicate0N [2, 2, 1, 2, 2] (rscalar 42) in tpair dt dt))
  where
    expected =
      rconcrete $ Nested.rfromListPrimLinear [2,2,1,2,2]
        [2.206713605648901e7,148548.45667978778,-1.0646476553002307e7,506148.0678803204,8.294896080825263e7,485287.9750115657,-3.844743184175473e7,5.460549007669466e7,2.271418850409362e7,13848.390132505112,-1.0690008160055092e7,511359.0281220066,7.774196371340603e7,483621.8424293669,-3.876191146034207e7,5.754048643554987e7]

-- | 'fooBuild5' on the rank-7 fixture @t48@. Rows 0 and 2 of the expected
-- gradient (16 entries each) are identical.
testFooBuild5 :: Assertion
testFooBuild5 =
  assertEqualUpToEpsilon' 1e-5 expected (rev' @Double @7 fooBuild5 t48)
  where
    expected = ringestData [3,1,2,2,1,2,2] (outerRow ++ middleRow ++ outerRow)
    outerRow =
      [-613291.6547530327,571164.2201603781,-1338602.6247083102,528876.2566682736,1699442.2143691683,2874891.369778316,-3456754.605470273,3239487.8744244366,554916.1344235454,-775449.1803684114,3072.200583200206,1165767.8436804386,-1.0686356667942494e7,-6606976.194539241,-6457671.748790982,4791868.42112978]
    middleRow =
      [-615556.7946425928,569660.3506343022,-1348678.1169100606,534886.9366492515,1696036.143341285,2883992.9672165257,-3456212.5353846983,3240296.690514803,629047.8398075115,-794389.5797803313,-1143.8025173051583,1177448.8083517442,-1.15145721735623e7,-6618648.839812404,-6462386.031613377,5358224.852822481]

-- | Like 'fooBuild5', but with a length-3 build, and with the constant 5 made
-- by replicating to the tail shape of @v@. Row ix is
-- r * foo (3, 5 * r, r * vMin) + bar (r, v[min 1 (ix + 1)]).
fooBuild1 :: forall target r n.
             ( ADReady target, GoodScalar r, KnownNat n, RealFloatH (target (TKR n r)) )
          => target (TKR (1 + n) r) -> target (TKR (1 + n) r)
fooBuild1 v = rbuild1 3 row
  where
    r = rsum v
    tk = rreplicate0N (shrTail (rshape v))
    vMin = tk (rminimum (rflatten v))
    row ix = r * foo (tk (rscalar 3), tk (rscalar 5) * r, r * vMin)
             + bar (r, rindex v [minH 1 (ix + 1)])

-- | 'fooBuild1' on the rank-5 fixture @t16@.
testFooBuild1 :: Assertion
testFooBuild1 =
  assertEqualUpToEpsilon' 1e-8 expected (rev' @Double @5 fooBuild1 t16)
  where
    expected =
      ringestData [2,2,1,2,2]
        [394056.00100873224,2652.651012139068,-190115.65273218407,9038.358355005721,1481231.4430045108,8665.8566966351,-686561.2828884773,975098.0370838332,405610.50900167174,247.29268093759174,-190893.00285812665,9131.411216464405,1388249.3520251075,8636.104329095837,-692176.9903632513,1027508.6863491047]

-- | Applies 'fooBuild1' to a tensor of shape @sh@ filled with @r * r@, then
-- maps every element @x@ to @x * r + 5@.
fooMap1 :: (ADReady target, GoodScalar r, KnownNat n, Differentiable r)
        => IShR (1 + n) -> target (TKR 0 r) -> target (TKR (1 + n) r)
fooMap1 sh r = rmap0N shift built
  where
    built = fooBuild1 $ rreplicate0N sh (r * r)
    shift x = x * r + rscalar 5

-- | Checks @rev'@ of 'fooMap1' (shape @[130]@) at scalar 0.1.
testFooMap :: Assertion
testFooMap = assertEqualUpToEpsilon' 1e-3 (rscalar 2.7518227) gradient
  where
    gradient = rev' @Float @1 (fooMap1 [130]) (rscalar 0.1)

-- A reduced test: the full-size version takes forever with Ast when
-- vectorization is not used.
testFooMap1 :: Assertion
testFooMap1 =
  assertEqualUpToEpsilon 1e-6 (rscalar 3901.312463734578) $
    grad (kfromR @_ @Double . rsum0 @7 . fooMap1 [4, 3, 2, 3, 4, 5, 3])
         (rscalar 0.1)

-- | A size-3 build mixing nested 'bar' calls with an 'ifH' on an indexed
-- element of @v@ and the build index. The build is divided elementwise by
-- @rslice 1 3@ of @v@ with every element capped at @rsum0 v@, then
-- multiplied by a build of ones.
fooNoGo :: forall target r n.
           ( ADReady target, GoodScalar r, KnownNat n, Differentiable r )
        => target (TKR (1 + n) r) -> target (TKR (1 + n) r)
fooNoGo v = (built / capped) * ones
  where
    outerSum = rsum v
    grandTotal = rsum0 v
    shTail = shrTail (rshape v)
    built = rbuild1 3 $ \ix ->
      bar ( rreplicate0N shTail (rscalar 3.14)
          , bar ( rrepl shTail 3.14
                , rindex v [ix]) )
      + ifH (rindex v (ix * 2 :.: ZIR) <=. rreplicate0N shTail (rscalar 0) &&* 6 >. abs ix)
            outerSum (rreplicate0N shTail (rscalar 5) * outerSum)
    capped = rslice 1 3 (rmap0N (\x -> ifH (x >. grandTotal) grandTotal x) v)
    ones = rbuild1 3 (const $ rrepl shTail 1)

-- | Checks @rev'@ of 'fooNoGo' at a 5-element vector.
testFooNoGo :: Assertion
testFooNoGo = assertEqualUpToEpsilon' 1e-6 expected actual
  where
    expected = ringestData [5] [344.3405885672822,-396.1811403813819,7.735358041386672,-0.8403418295960372,5.037878787878787]
    actual = rev' @Double @1 fooNoGo (ringestData [5] [1.1 :: Double, 2.2, 3.3, 4, 5])

-- | Checks 'grad' of the summed, 1e-9-scaled 'fooNoGo' at five replicas of
-- 't48' (scaled by 0.01) against hard-coded expected values.
testFooNoGo10 :: Assertion
testFooNoGo10 =
  assertEqualUpToEpsilon 1e-10
    (ringestData [5, 3, 1, 2, 2, 1, 2, 2]
       [8.096867407436072e-8,9.973025492756426e-8,9.976696178938985e-8,5.614458707681111e-8,-1.8338500573636686e-7,-2.144970334428336e-7,7.354143606421902e-7,-1.8140041785503643e-7,8.096867407436072e-8,9.973025492756426e-8,9.976696178938985e-8,5.614458707681111e-8,-2.01381292700262e-7,-2.221588091014473e-7,7.354143606421902e-7,-1.9951065225263367e-7,1.7230532848112822e-7,4.5426218104870796e-7,1.430886696893587e-7,9.354993295163118e-7,-5.225515010723883e-7,1.019433073376504e-6,9.64067025472343e-6,-4.872227980305747e-6,8.089200625992941e-8,9.924319994964371e-8,1.092480101004153e-7,-2.8478802468285825e-7,9.641049518625974e-8,2.9624147815716037e-7,-1.950868158558337e-7,9.547754822865364e-8,4.5426218104870796e-7,4.5426218104870796e-7,4.5426218104870796e-7,4.5426218104870796e-7,-4.872227980305747e-6,-4.872227980305747e-6,-4.872227980305747e-6,-4.872227980305747e-6,9.361277121832246e-8,-4.872227980305747e-6,-4.872227980305747e-6,-4.872227980305747e-6,9.361277121832246e-8,9.361277121832246e-8,9.361277121832246e-8,9.361277121832246e-8,-5.488572216677945e-7,-1.8496203182958057e-7,-1.4603644180845103e-7,-1.2145268106051633e-7,-2.817402689957553e-7,-2.9913537180597976e-7,6.272804203945257e-7,-2.3697344464172694e-7,-5.488572216677945e-7,-1.8496203182958057e-7,-1.4603644180845103e-7,-1.2145268106051633e-7,-2.613973017956691e-7,-3.0013408634207794e-7,6.272804203945257e-7,-2.916736028401805e-7,-7.0114505846358575e-6,-4.303381366239431e-5,-4.897282418246382e-6,-1.710952247892854e-4,-4.2040039667393255e-5,-2.0204742564752248e-4,-1.7017980671040968e-2,-4.247008401789142e-3,-1.056090348050961e-6,-2.210187184450231e-6,-2.7842041329045203e-6,-1.0402806498987974e-5,-1.2967382896879757e-7,-1.9315601705070884e-5,-2.40087090725031e-7,-2.4419692405172046e-7,-4.303381366239431e-5,-4.303381366239431e-5,-4.303381366239431e-5,-4.303381366239431e-5,-4.247008401789142e-3,-4.247008401789142e-3,-4.247008401789142e-3,-4.247008401789142e-3,-2.683138631810477e-7
       ,-4.247008401789142e-3,-4.247008401789142e-3,-4.247008401789142e-3,-2.683138631810477e-7,-2.683138631810477e-7,-2.683138631810477e-7,-2.683138631810477e-7,-5.488572216677945e-7,-1.8496203182958057e-7,-1.4603644180845103e-7,-1.2145268106051633e-7,-2.817402689957553e-7,-2.9913537180597976e-7,6.272804203945257e-7,-2.3697344464172694e-7,-5.488572216677945e-7,-1.8496203182958057e-7,-1.4603644180845103e-7,-1.2145268106051633e-7,-2.613973017956691e-7,-3.0013408634207794e-7,6.272804203945257e-7,-2.916736028401805e-7,-7.0114505846358575e-6,-4.303381366239431e-5,-4.897282418246382e-6,-1.710952247892854e-4,-4.2040039667393255e-5,-2.0204742564752248e-4,-1.7017980671040968e-2,-4.247008401789142e-3,-1.056090348050961e-6,-2.210187184450231e-6,-2.7842041329045203e-6,-1.0402806498987974e-5,-1.2967382896879757e-7,-1.9315601705070884e-5,-2.40087090725031e-7,-2.4419692405172046e-7,-4.303381366239431e-5,-4.303381366239431e-5,-4.303381366239431e-5,-4.303381366239431e-5,-4.247008401789142e-3,-4.247008401789142e-3,-4.247008401789142e-3,-4.247008401789142e-3,-2.683138631810477e-7,-4.247008401789142e-3,-4.247008401789142e-3,-4.247008401789142e-3,-2.683138631810477e-7,-2.683138631810477e-7,-2.683138631810477e-7,-2.683138631810477e-7,-5.469529675653596e-7,-2.331458950045675e-7,-1.9907443163522408e-7,-1.4019078434680374e-7,-6.95091094132346e-8,-5.685763846730528e-8,-9.268594848659335e-8,-3.010367762029461e-8,-5.469529675653596e-7,-2.331458950045675e-7,-1.9907443163522408e-7,-1.4019078434680374e-7,-3.415394012988984e-8,-5.069973314807702e-8,-9.268594848659335e-8,-6.380451815099858e-8,-6.883755913116986e-6,-4.273807584344302e-5,-4.79037108793574e-6,-1.705307241188017e-4,-4.2267488166320864e-5,-2.0143642393829028e-4,-1.701262134129569e-2,-4.2496361738088365e-3,-1.0224785375169973e-6,-2.1427637177332083e-6,-2.705952143004936e-6,-1.0493018474305117e-5,-1.819666770962338e-7,-1.911089472080586e-5,-9.045482032374276e-8,-2.819821645880664e-7,-4.273807584344302e-5,-4.273807584344302e-5,-4.273807584344302e-5
       ,-4.273807584344302e-5,-4.2496361738088365e-3,-4.2496361738088365e-3,-4.2496361738088365e-3,-4.2496361738088365e-3,-3.019273543907303e-7,-4.2496361738088365e-3,-4.2496361738088365e-3,-4.2496361738088365e-3,-3.019273543907303e-7,-3.019273543907303e-7,-3.019273543907303e-7,-3.019273543907303e-7,8.287292817679557e-8,5.154639175257732e-8,4.672897196261682e-8,3.740648379052369e-8,2.884615384615385e-8,2.7780699895840894e-8,1.5447991761071065e-8,2.546934916639589e-8,8.287292817679557e-8,5.154639175257732e-8,4.672897196261682e-8,3.740648379052369e-8,2.5862068965517245e-8,2.7275544092553562e-8,1.5447991761071065e-8,2.8358432436548274e-8,3.0000000000000004e-7,7.500000000000001e-7,2.5000000000000004e-7,1.5000000000000002e-6,-7.500000000000001e-7,1.6304347826086957e-6,1.5000000000000002e-5,-7.500000000000001e-6,1.1450381679389314e-7,1.6666666666666668e-7,1.8750000000000003e-7,-3.7500000000000006e-7,4.411764705882353e-8,5.00948462422186e-7,-4.545454545454546e-8,5.76923076923077e-8,7.500000000000001e-7,7.500000000000001e-7,7.500000000000001e-7,7.500000000000001e-7,-7.500000000000001e-6,-7.500000000000001e-6,-7.500000000000001e-6,-7.500000000000001e-6,5.999928000863991e-8,-7.500000000000001e-6,-7.500000000000001e-6,-7.500000000000001e-6,5.999928000863991e-8,5.999928000863991e-8,5.999928000863991e-8,5.999928000863991e-8])
   (grad (kfromR @_ @Double . rsum0 @8 . rmap0N (* rscalar 0.000000001) . fooNoGo) (rmap0N (* rscalar 0.01) $ rreplicate 5 t48))

-- | Builds 3 slices, each the replicated minimum of a 7-element window
-- (starting at the slice index) of a constant 177-element vector of @r@.
-- The slice shape is the first @n@ dimensions of @[2, 4, 2, 1, 3, 2]@.
-- Every element @x@ is then scaled by a sum combining 'bar', 'fooBuild1'
-- and 'fooMap1' evaluated at @x@.
nestedBuildMap :: forall target n r.
                  (ADReady target, GoodScalar r, n <= 6, KnownNat n, Differentiable r)
               => target (TKR 0 r) -> target (TKR (1 + n) r)
nestedBuildMap r = rmap0N scaleBy doublyBuild
  where
    w x = rreplicate0N [4] x :: target (TKR 1 r)
    constVec = rreplicate0N (177 :$: ZSR) r
    nestedMap x = rmap0N (x /) (w x)
    window iy = rbuild1 7 (\ix -> rindex constVec (ix + iy :.: ZIR))
    doublyBuild =
      rbuild1 3 (rreplicate0N (shrTake @n @(6 - n)
                                 $ 2 :$: 4 :$: 2 :$: 1 :$: 3 :$: 2 :$: ZSR)
                 . rminimum . window)
    scaleBy x =
      x * rsum0 (rbuild1 3 (\ix -> bar (x, rindex constVec [ix]))
                 + fooBuild1 (nestedMap x)
                 / fooMap1 [3] x)

-- | Checks @rev'@ of 'nestedBuildMap' (rank 2 result) at scalar 0.6.
testNestedBuildMap1 :: Assertion
testNestedBuildMap1 =
  assertEqualUpToEpsilon' 1e-8 (rscalar 22.673212907588812) $
    rev' @Double @1 nestedBuildMap (rscalar 0.6)

-- | Checks 'cgrad' of the summed 'nestedBuildMap' at the ten points
-- 0.1, 0.2 .. 1, visited three times in a row.
testNestedBuildMap10 :: Assertion
testNestedBuildMap10 =
  assertEqualUpToEpsilon 1e-8
    (map rscalar (thrice gradsAtPoints))
    (map (cgrad (kfromR . rsum0 @1 @(TKScalar Double) . nestedBuildMap))
         (map (Concrete . Nested.rscalar) (thrice points)))
  where
    thrice xs = xs ++ xs ++ xs
    points = [0.1, 0.2 .. 1]
    gradsAtPoints = [109.62086996459126,106.70290239773645,103.05843225947055,98.11825678264942,67.8014491889543,22.67321290758882,-163.40832575807545,376.4240286600336,-1996.9068313949347,249.28292226561257]

-- | Same as 'testNestedBuildMap10' but through 'grad' instead of 'cgrad'.
testNestedBuildMap11 :: Assertion
testNestedBuildMap11 =
  assertEqualUpToEpsilon 1e-8
    (map rscalar (thrice gradsAtPoints))
    (map (grad (kfromR . rsum0 @1 @(TKScalar Double) . nestedBuildMap))
         (map (Concrete . Nested.rscalar) (thrice points)))
  where
    thrice xs = xs ++ xs ++ xs
    points = [0.1, 0.2 .. 1]
    gradsAtPoints = [109.62086996459126,106.70290239773645,103.05843225947055,98.11825678264942,67.8014491889543,22.67321290758882,-163.40832575807545,376.4240286600336,-1996.9068313949347,249.28292226561257]

{-
testNestedBuildMap7 :: Assertion
testNestedBuildMap7 =
  assertEqualUpToEpsilon' 1e-8
    (rscalar 2176.628439128524)
    (rev' @Double @7 nestedBuildMap (rscalar 0.6))
-}

-- The @n <= 4@ constraint is required even though GHC claims otherwise;
-- adding a type application such as @(2 + n)@ to 'nestedBuildMap' does not
-- remove the need for it.
--
-- Builds a 13 x 4 grid. At @(ix1, ix2)@ it indexes 'nestedBuildMap' with the
-- clamped triple @(ix2 rem 3, min 1 ix1, min ix1 3)@. When @ix2 > ix1@ the
-- argument is @rsum0 v@ and the result is mapped through
-- @(* -3e-8) . sqrt . abs@; otherwise the argument is the constant 0.00042.
nestedSumBuild
  :: forall target n r.
     (ADReady target, GoodScalar r, n <= 4, KnownNat n, Differentiable r)
  => target (TKR n r) -> target (TKR (2 + n) r)
nestedSumBuild v =
  rbuild1 13 $ \ix1 -> rbuild1 4 $ \ix2 ->
    let ixs = ix2 `remH` 3 :.: minH 1 ix1 :.: minH ix1 3 :.: ZIR
    in ifH (ix2 >. ix1)
           (rmap0N ((* rscalar (-0.00000003)) . sqrt . abs)
                   (nestedBuildMap (rsum0 v) `rindex` ixs))
           (nestedBuildMap (rscalar 0.00042) `rindex` ixs)

-- | Checks 'grad' of the summed 'nestedSumBuild' at a 5-element vector.
testNestedSumBuild1 :: Assertion
testNestedSumBuild1 = assertEqualUpToEpsilon 1e-6 expected actual
  where
    expected = ringestData [5] [5.738943380972744e-6,5.738943380972744e-6,5.738943380972744e-6,5.738943380972744e-6,5.738943380972744e-6]
    actual = grad (kfromR . rsum0 @3 @(TKScalar Double) . nestedSumBuild) (ringestData [5] [1.1, 2.2, 3.3, 4, -5.22])

{-
testNestedSumBuild5 :: Assertion
testNestedSumBuild5 =
  assertEqualUpToEpsilon' 1e-6
    (ringestData [1,2,2] [3.5330436757054903e-3,3.5330436757054903e-3,3.5330436757054903e-3,3.5330436757054903e-3])
    (rev' @Double @5 nestedSumBuild (rsum (rsum t16)))
-}

-- | Builds a rank-3 tensor of shape @[13, 4, 2]@ with 'rbuild' over the two
-- outer indices. The element at @[ix, ix2]@ is entry @ix2@ of the list of
-- rank-1 tensors below, so every entry must have length 2.
--
-- NOTE(review): @ix2@ ranges over 0..3 (the second extent is 4), so the
-- fifth list entry is never selected. Confirm this is intended.
nestedSumBuildB :: forall target n r. (ADReady target, GoodScalar r, KnownNat n)
                => target (TKR (1 + n) r) -> target (TKR 3 r)
nestedSumBuildB v =
  rbuild @2 [13, 4, 2] $ \case
    -- ix ranges over the first extent (13), ix2 over the second (4).
    [ix, ix2] ->
      flip rindex [ix2]
        (rfromList
             [ rbuild1 2 rfromIndex0
             , rsum $ rbuild [9, 2] $ const $ rfromIndex0 ix
             -- Indexes v with the same value, clamped to [0, 1], repeated
             -- (rlength v - 1) times; presumably rlength is v's rank, so this
             -- selects along every dimension but the last (TODO confirm).
             , rindex v (fromList
                         $ replicate (rlength v - 1)
                             (maxH 0 $ minH 1 $ ix2 `quotH` 2 + ix `quotH` 4 - 1))
             , rbuild1 2 (\_ -> rsum0 v)
             , rsum (rbuild1 7 (\ix7 ->
                 rreplicate 2 (rfromIndex0 ix7)))
             ])
    -- rbuild @2 always passes two indices; this branch only exists to make
    -- the list pattern match exhaustive.
    _ -> error "nestedSumBuildB: impossible pattern needlessly required"

-- | Checks @rev'@ of 'nestedSumBuildB' at a reduced, transposed 't48'.
testNestedSumBuildB :: Assertion
testNestedSumBuildB = assertEqualUpToEpsilon' 1e-8 expected actual
  where
    expected = ringestData [2,3,2,2,2] [30.0,30.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,35.0,35.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0]
    actual = rev' @Double @3 nestedSumBuildB (rsum (rsum (rtranspose [1, 4, 2, 0, 3] t48)))

-- | Three nested builds, each immediately indexed. The innermost build reads
-- @v@ at @(ix4 rem 2, ix2, 0)@.
nestedBuildIndex :: forall target r. (ADReady target, GoodScalar r)
                 => target (TKR 5 r) -> target (TKR 3 r)
nestedBuildIndex v =
  rbuild1 2 $ \i2 ->
    let middle =
          rbuild1 3 $ \i3 ->
            let inner =
                  rbuild1 3 $ \i4 ->
                    rindex v (i4 `remH` 2 :.: i2 :.: 0 :.: ZIR)
            in rindex inner [i3]
    in rindex middle (i2 :.: ZIR)

-- | Checks @rev'@ of 'nestedBuildIndex' at 't16'.
testNestedBuildIndex :: Assertion
testNestedBuildIndex = assertEqualUpToEpsilon' 1e-10 expected actual
  where
    expected = ringestData [2,2,1,2,2] [1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0]
    actual = rev' @Double @3 nestedBuildIndex t16

-- | Scales @x@ elementwise by 0.001, applies 'bar' to the scaled value and
-- its 'relu', and finally applies 'relu' to the result.
barRelu
  :: ( ADReady target, GoodScalar r, KnownNat n, Differentiable r )
  => target (TKR n r) -> target (TKR n r)
barRelu x = relu (bar (scaled, relu scaled))
  where
    scaled = rreplicate0N (rshape x) (rscalar 0.001) * x

-- | Checks 'vjp' of 'barRelu' at 't16' with a constant 42.2 cotangent.
testBarReluADValDt :: Assertion
testBarReluADValDt =
  assertEqualUpToEpsilon 1e-6 expected $
    vjp @_ @(TKR 5 Double) barRelu t16 dt
  where
    dt = rreplicate0N [ 2 , 2 , 1 , 2 , 2 ] (rscalar 42.2)
    expected = rconcrete $ Nested.rfromListPrimLinear [2,2,1,2,2] [1.2916050471365906e-2,1.2469757606504572e-2,1.3064120086501589e-2,1.2320300700062944e-2,0.0,1.217049789428711e-2,1.2185494267265312e-2,0.0,1.4105363649830907e-2,1.3506236503127638e-2,1.3359213691150671e-2,0.0,1.7066665416485535e-2,1.2618022646204737e-2,0.0,1.595161947206668e-2]

-- | Checks 'vjp' of a function returning a nested product (a sum, a Float
-- cast of 'barRelu' and a shaped 'barRelu') at 't16'.
testBarReluADValDt2 :: Assertion
testBarReluADValDt2 =
  assertEqualUpToEpsilon 1e-6 expected
    (vjp @_ @(TKProduct
                (TKR 4 Double)
                (TKProduct (TKR 5 Float)
                           (TKS [2,2,1,2,2] Double)))
         (\x -> tpair (rsum x) (tpair (rcast $ barRelu x) (sfromR $ barRelu x)))
         t16
         (tpair (rsum dt) (tpair (rcast dt) (sfromR dt))))
  where
    dt = rreplicate0N [ 2 , 2 , 1 , 2 , 2 ] (rscalar 42.2)
    expected = rconcrete $ Nested.rfromListPrimLinear [2,2,1,2,2] [84.42583210117625,84.42493951543845,84.4261282404092,84.42464060162287,84.4,84.42434099465609,84.4243709887547,84.4,84.42821072755468,84.42701247325044,84.42671842762383,84.4,84.43413333114152,84.42523604552053,84.4,84.43190323923253]

-- | Checks @rev'@ of 'barRelu' at 't48'.
testBarReluADVal :: Assertion
testBarReluADVal = assertEqualUpToEpsilon' 1e-10 expected actual
  where
    expected = ringestData [3,1,2,2,1,2,2] [3.513740871835189e-4,3.8830416352632824e-4,3.981974371104471e-4,4.2420226755643853e-4,4.6186212581292275e-4,4.6805323209889415e-4,5.933633926875981e-4,4.8311739820100107e-4,3.513740871835189e-4,3.8830416352632824e-4,3.981974371104471e-4,4.2420226755643853e-4,4.803836032226148e-4,4.7114455958615145e-4,5.933633926875981e-4,4.6464270870595213e-4,3.060675467148428e-4,2.954918864100193e-4,3.095763053673437e-4,2.9195025355591045e-4,0.0,2.9166656928452994e-4,2.887557883241243e-4,0.0,3.342503234557057e-4,3.2005299770444394e-4,3.165690448140097e-4,0.0,4.0442335110155446e-4,2.990052759764126e-4,0.0,3.780004614233832e-4,2.954918864100193e-4,2.954918864100193e-4,2.954918864100193e-4,2.954918864100193e-4,0.0,0.0,0.0,0.0,3.7466025157760897e-4,0.0,0.0,0.0,3.7466025157760897e-4,3.7466025157760897e-4,3.7466025157760897e-4,3.7466025157760897e-4]
    actual = rev' @Double @7 barRelu t48

-- | Checks @rev'@ of 'barRelu' at 't48' scaled by 0.001.
testBarReluADVal3 :: Assertion
testBarReluADVal3 = assertEqualUpToEpsilon' 1e-10 expected actual
  where
    expected = ringestData [3,1,2,2,1,2,2] [2.8846476339094805e-4,2.885038541771792e-4,2.885145151321922e-4,2.8854294397024206e-4,2.885852309100301e-4,2.885923176600045e-4,2.887454843457817e-4,2.886097295122454e-4,2.8846476339094805e-4,2.885038541771792e-4,2.885145151321922e-4,2.8854294397024206e-4,2.8860655161315664e-4,2.88595871110374e-4,2.887454843457817e-4,2.885884088500461e-4,2.884182085399516e-4,2.884075468755327e-4,2.8842176240868867e-4,2.8840399312321096e-4,0.0,2.8840370860416445e-4,2.884007943794131e-4,0.0,2.884469945274759e-4,2.8843242392031246e-4,2.884288700806792e-4,0.0,2.885212670262263e-4,2.884110805753153e-4,0.0,2.8849283778617973e-4,2.884075468755327e-4,2.884075468755327e-4,2.884075468755327e-4,2.884075468755327e-4,0.0,0.0,0.0,0.0,2.884892851579934e-4,0.0,0.0,0.0,2.884892851579934e-4,2.884892851579934e-4,2.884892851579934e-4,2.884892851579934e-4]
    actual = rev' @Double @7 barRelu (rmap0N (* rscalar 0.001) t48)

-- | Builds a 3 x 4 tensor whose element @(ix1, ix2)@ is entry @ix1@ of the
-- list @[ix2, 7, rsum0 (rslice 1 1 r), -0.2]@. Since @ix1 < 3@, the last
-- entry is never selected.
braidedBuilds :: forall target n r. (ADReady target, GoodScalar r, KnownNat n, Differentiable r)
              => target (TKR (1 + n) r) -> target (TKR 2 r)
braidedBuilds r =
  rbuild1 3 $ \ix1 ->
    rbuild1 4 $ \ix2 ->
      let choices = rfromList [ rfromIndex0 ix2, rscalar 7
                              , rsum0 (rslice 1 1 r), rscalar (-0.2) ]
      in rindex choices (ix1 :.: ZIR)

-- | Checks @rev'@ of 'braidedBuilds' at a constant 4-element vector.
testBraidedBuilds :: Assertion
testBraidedBuilds =
  assertEqualUpToEpsilon' 1e-10 (ringestData [4] [0.0,4.0,0.0,0.0]) $
    rev' @Double @2 (braidedBuilds @_ @0) (rreplicate0N [4] (rscalar 3.4))

-- | Checks @rev'@ of 'braidedBuilds' at 't16'.
testBraidedBuilds1 :: Assertion
testBraidedBuilds1 = assertEqualUpToEpsilon' 1e-10 expected actual
  where
    expected = ringestData [2,2,1,2,2] [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0]
    actual = rev' @Double @2 braidedBuilds t16

-- | Nests 'nestedSumBuildB' applied to four replicas of @r@ inside builds of
-- sizes 2, 4, 2 and 3. All build indices are ignored.
recycled :: (ADReady target, GoodScalar r, KnownNat n)
         => target (TKR n r) -> target (TKR 7 r)
recycled r =
  rbuild1 2 (\_i0 ->
    rbuild1 4 (\_i1 ->
      rbuild1 2 (\_i2 ->
        rbuild1 3 (\_i3 -> nestedSumBuildB (rreplicate 4 r)))))

-- | Checks @rev'@ of 'recycled' at a constant 2-element vector.
testRecycled :: Assertion
testRecycled =
  assertEqualUpToEpsilon' 1e-6 (rrepl [2] 5616) $
    rev' @Double @7 (recycled @_ @_ @1) (rreplicate0N [2] (rscalar 1.0001))

{-
testRecycled1 :: Assertion
testRecycled1 =
  assertEqualUpToEpsilon' 1e-6
    (ringestData [5, 4, 2] [5184.0,5184.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,5424.0,5424.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0])
    (rev' @Double @7 (recycled @_ @_ @3) (rreplicate0N [5, 4, 2] (rscalar 0.0002)))
-}

-- | For each of 7 outer indices @i@, concatenates (presumably along the
-- outermost dimension) four builds derived from @r@, with 5 + 1 + 11 + 13
-- outer entries.
concatBuild :: forall target r n.
               (ADReady target, GoodScalar r, KnownNat n, Differentiable r)
            => target (TKR (1 + n) r) -> target (TKR (3 + n) r)
concatBuild r =
  rbuild1 7 (\i ->
    -- 1) five unchanged copies of r
    rconcat [ rbuild1 5 (const r)
            -- 2) a single copy (j is always 0) of r scaled by (j - i)
            , rbuild1 1 (\j -> rmap0N (* rfromIndex0 (j - i)) r)
            -- 3) eleven copies of r, each scaled by an integer built from
            --    i, j, floor (rsum0 r), and either the argmin index of r or
            --    the floor of the sum of slice ((i * j) rem 7) of r.
            --    NOTE(review): @r <=. r@ is presumably always true (barring
            --    NaNs), so the branch is effectively decided by @i <. j@.
            , rbuild1 11 (\j ->
                rmap0N (* (rfromIndex0
                  (kfromR (rprimalPart @target (rscalar 125)) * (j `remH` (abs (signum i + abs i) + 1))
                   + maxH j (i `quotH` (j + 1)) * (kfromR . rprimalPart . rfloor) (rsum0 r)
                   - ifH (r <=. r &&* i <. j)
                         (kfromR $ rprimalPart $ rminIndex (rflatten r))
                         ((kfromR . rprimalPart . rfloor) $ rsum0 $ r ! ((i * j) `remH` 7 :.: ZIR))))) r)
            -- 4) thirteen copies of a tensor built from r's first slice:
            --    @rwidth r@ replicas of @rslice 0 1 r@, transposed and summed.
            , rbuild1 13 (\_k ->
                rsum $ rtr $ rreplicate (rwidth r) (rslice 0 1 r)) ])

-- | Checks @rev'@ of 'concatBuild' at a 7-element vector.
testConcatBuild0 :: Assertion
testConcatBuild0 = assertEqualUpToEpsilon' 1e-10 expected actual
  where
    input = ringestData [7] [0.651,0.14,0.3414,-0.14,0.0014,0.0020014,0.9999]
    expected = ringestData [7] [16917.0,16280.0,16280.0,16280.0,16280.0,16280.0,16280.0]
    actual = rev' @Double @3 concatBuild input

testConcatBuild1 :: Assertion
testConcatBuild1 =
  assertEqualUpToEpsilon 1e-10
    (ringestData [3,1,2,2,1,2,2] [1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3])
    (grad (kfromR . rsum0 @9 @(TKScalar Double) . concatBuild . rmap0N (* rscalar 1e-7)) t48)

-- | For each @i@ in 0..6, multiplies every element of @r@ by the floor of the
-- sum of the @i@-th outer slice of @r@, converted through an index value.
concatBuildm :: forall target r n.
                (ADReady target, GoodScalar r, KnownNat n, Differentiable r)
             => target (TKR (1 + n) r) -> target (TKR (2 + n) r)
concatBuildm r =
  rbuild1 7 $ \i ->
    let k = (kfromR . rprimalPart . rfloor) (rsum0 (r ! (i :.: ZIR)))
    in rmap0N (* rfromIndex0 k) r

-- | Checks @rev'@ of 'concatBuildm' at a 7-element vector.
testConcatBuild0m :: Assertion
testConcatBuild0m = assertEqualUpToEpsilon' 1e-10 expected actual
  where
    input = ringestData [7] [0.651,0.14,0.3414,-0.14,0.0014,0.0020014,0.9999]
    expected = ringestData [7] [-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0]
    actual = rev' @Double @2 concatBuildm input

-- | Checks @rev'@ of 'concatBuildm' at 't48' scaled by 1e-7.
testConcatBuild1m :: Assertion
testConcatBuild1m = assertEqualUpToEpsilon' 1e-10 expected actual
  where
    expected = ringestData [3,1,2,2,1,2,2] [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]
    actual = rev' @Double @8 (concatBuildm . rmap0N (* rscalar 1e-7)) t48

-- | Builds a 5 x 2 grid of copies of @r@, the copy at @(i, j)@ scaled by
-- @max j (i `quot` (j + 1))@.
concatBuild2 :: (ADReady target, GoodScalar r, KnownNat n)
             => target (TKR (1 + n) r) -> target (TKR (3 + n) r)
concatBuild2 r =
  rbuild1 5 $ \i ->
    rbuild1 2 $ \j ->
      let factor = rfromIndex0 (maxH j (i `quotH` (j + 1)))
      in rmap0N (* factor) r

-- | Checks @rev'@ of 'concatBuild2' at a 3-element vector.
testConcatBuild2 :: Assertion
testConcatBuild2 =
  assertEqualUpToEpsilon' 1e-10 (ringestData [3] [16.0,16.0,16.0]) $
    rev' @Double @3 concatBuild2 (ringestData [3] [0.651,0.14,0.3414])

testConcatBuild22 :: Assertion
testConcatBuild22 =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [3,1,2,2,1,2,2] [16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0])
    (rev' @Double @9 concatBuild2 t48)

-- | Ignores its argument and builds a 5 x 2 tensor whose element at
-- @(i, j)@ is @max j (i `quot` (j + 1))@.
concatBuild3 :: (ADReady target, GoodScalar r)
             => target (TKR 1 r) -> target (TKR 2 r)
concatBuild3 _r =
  rbuild1 5 $ \i ->
    rbuild1 2 $ \j ->
      let m = maxH j (i `quotH` (j + 1))
      in rfromIndex0 m

-- | Checks @rev'@ of 'concatBuild3' on an empty input.
testConcatBuild3 :: Assertion
testConcatBuild3 =
  assertEqualUpToEpsilon' 1e-10 (ringestData [0] []) $
    rev' @Double @2 concatBuild3 (ringestData [0] [])

-- | Checks @rev'@ of 'logistic' at scalar 3.
testLogistic0 :: Assertion
testLogistic0 =
  assertEqualUpToEpsilon' 1e-10 (rscalar 4.5176659730912e-2) $
    rev' @Double @0 logistic (rscalar 3)

-- | Checks @rev'@ of 'logistic' at 't16'.
testLogistic5 :: Assertion
testLogistic5 = assertEqualUpToEpsilon' 1e-10 expected actual
  where
    expected = rfromListLinear [2,2,1,2,2] [6.648056670790033e-3,0.10499358540350662,2.466509291359931e-3,0.19661193324148185,0.1049935854035065,0.2499999999999375,0.24937604019289197,0.24751657271185995,2.0452222584760427e-6,1.2337934976493025e-4,3.3523767075636815e-4,1.7662706213291118e-2,1.7763568394002473e-15,4.540945566439111e-2,4.6588861451033536e-15,5.109024314693943e-12]
    actual = rev' @Double @5 logistic t16

-- | Checks @rev'@ of 'logistic' composed with itself at 't16'.
testLogistic52 :: Assertion
testLogistic52 = assertEqualUpToEpsilon' 1e-10 expected actual
  where
    expected = rfromListLinear [2,2,1,2,2] [1.3111246391159124e-3,2.1750075272657612e-2,4.8549901151740267e-4,4.312916016242333e-2,2.6155373652699744e-2,5.8750924453083504e-2,5.8238333583278255e-2,5.8847120842749026e-2,4.0211548220008027e-7,2.425923569564766e-5,6.592194028602285e-5,4.415319450597324e-3,3.4925295232121135e-16,9.12284655835676e-3,1.1647215362758384e-15,1.0044951474920845e-12]
    actual = rev' @Double @5 (logistic . logistic) t16

-- | Logistic sigmoid with a hand-written derivative. The primal is
-- @1 / (1 + exp (-x))@; the dual part of the input is scaled by @y * (1 - y)@
-- and recombined with 'tD'.
logisticOld :: forall target r n.
            ( BaseTensor target, LetTensor target
            , BaseTensor (PrimalOf target), KnownNat n, GoodScalar r
            , Floating (PrimalOf target (TKR n r)) )
         => target (TKR n r) -> target (TKR n r)
logisticOld d0 = tlet d0 $ \d ->
  -- d is used both in rprimalPart and in rdualPart below.
  let sh = rshape d
      sigmoidPrimal = recip (rrepl sh 1 + exp (- rprimalPart @target d))
  in tlet (rfromPrimal @target sigmoidPrimal) $ \shared ->
       let y = rprimalPart @target shared
       in tD knownSTK y (rScale @target (y * (rrepl sh 1 - y))
                                        (rdualPart @target d))

-- | Checks @rev'@ of 'logisticOld' at scalar 3.
testLogistic0Old :: Assertion
testLogistic0Old =
  assertEqualUpToEpsilon' 1e-10 (rscalar 4.5176659730912e-2) $
    rev' @Double @0 logisticOld (rscalar 3)

-- | Checks @rev'@ of 'logisticOld' at 't16'.
testLogistic5Old :: Assertion
testLogistic5Old = assertEqualUpToEpsilon' 1e-10 expected actual
  where
    expected = rfromListLinear [2,2,1,2,2] [6.648056670790033e-3,0.10499358540350662,2.466509291359931e-3,0.19661193324148185,0.1049935854035065,0.2499999999999375,0.24937604019289197,0.24751657271185995,2.0452222584760427e-6,1.2337934976493025e-4,3.3523767075636815e-4,1.7662706213291118e-2,1.7763568394002473e-15,4.540945566439111e-2,4.6588861451033536e-15,5.109024314693943e-12]
    actual = rev' @Double @5 logisticOld t16

-- | Checks @rev'@ of 'logisticOld' composed with 'logistic' at 't16'.
testLogistic52Old :: Assertion
testLogistic52Old = assertEqualUpToEpsilon' 1e-10 expected actual
  where
    expected = rfromListLinear [2,2,1,2,2] [1.3111246391159124e-3,2.1750075272657612e-2,4.8549901151740267e-4,4.312916016242333e-2,2.6155373652699744e-2,5.8750924453083504e-2,5.8238333583278255e-2,5.8847120842749026e-2,4.0211548220008027e-7,2.425923569564766e-5,6.592194028602285e-5,4.415319450597324e-3,3.4925295232121135e-16,9.12284655835676e-3,1.1647215362758384e-15,1.0044951474920845e-12]
    actual = rev' @Double @5 (logisticOld . logistic) t16

-- | Logistic sigmoid assembled from separate primal and dual summands. The
-- primal is @1 / (1 + exp (-x))@; the dual part of the input is scaled by
-- @y * (1 - y)@ and added via 'rfromDual'.
logisticA :: forall target r n.
            ( BaseTensor target, LetTensor target
            , BaseTensor (PrimalOf target), KnownNat n, GoodScalar r
            , Floating (PrimalOf target (TKR n r)) )
         => target (TKR n r) -> target (TKR n r)
logisticA d0 = tlet d0 $ \d ->
  -- d is used both in rprimalPart and in rdualPart below.
  let sh = rshape d
      sigmoidPrimal = recip (rrepl sh 1 + exp (- rprimalPart @target d))
  in tlet (rfromPrimal @target sigmoidPrimal) $ \shared ->
       let y = rprimalPart @target shared
           slope = y * (rrepl sh 1 - y)
       in rfromPrimal y
          + rfromDual (rScale @target slope (rdualPart @target d))

-- | Checks @rev'@ of 'logisticA' at scalar 3.
testLogisticA0 :: Assertion
testLogisticA0 =
  assertEqualUpToEpsilon' 1e-10 (rscalar 4.5176659730912e-2) $
    rev' @Double @0 logisticA (rscalar 3)

-- | Like 'logisticA' but the dual part of the input is added unscaled, so
-- the derivative passes through as the identity while the primal is still
-- @1 / (1 + exp (-x))@.
logisticB :: forall target r n.
            ( BaseTensor target, LetTensor target
            , BaseTensor (PrimalOf target), KnownNat n, GoodScalar r
            , Floating (PrimalOf target (TKR n r)) )
         => target (TKR n r) -> target (TKR n r)
logisticB d0 = tlet d0 $ \d ->
  -- d is used both in rprimalPart and in rdualPart below.
  let sh = rshape d
      sigmoidPrimal = recip (rrepl sh 1 + exp (- rprimalPart @target d))
  in tlet (rfromPrimal @target sigmoidPrimal) $ \shared ->
       rfromPrimal (rprimalPart @target shared)
       + rfromDual (rdualPart @target d)

-- | Checks @rev'@ of 'logisticB' at scalar 3.
testLogisticB0 :: Assertion
testLogisticB0 =
  assertEqualUpToEpsilon' 1e-10 (rscalar 1) $
    rev' @Double @0 logisticB (rscalar 3)

-- | Passes the input through its primal part only. Only 'rprimalPart' of
-- the let-bound input is used; nothing is taken from its dual part.
logisticC :: forall target r n.
            ( BaseTensor target, LetTensor target
            , KnownNat n, GoodScalar r )
         => target (TKR n r) -> target (TKR n r)
logisticC d0 = tlet d0 $ \d -> rfromPrimal @target (rprimalPart @target d)

-- | Checks @rev'@ of 'logisticC' at scalar 3.
testLogisticC0 :: Assertion
testLogisticC0 =
  assertEqualUpToEpsilon' 1e-10 (rscalar 0) $
    rev' @Double @0 logisticC (rscalar 3)
