{-# LANGUAGE AllowAmbiguousTypes, OverloadedLists #-}
{-# OPTIONS_GHC -fno-cse #-}
{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}
-- | Tests of convolution and disparity cost volume defined using the build
-- operation of ranked tensors.
module TestConvSimplified (testTrees) where

import Prelude

import Control.Exception.Assert.Sugar
import GHC.Exts (IsList (..))
import GHC.TypeLits (KnownNat)
import Test.Tasty
import Test.Tasty.HUnit hiding (assert)

import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Ranked.Shape

import HordeAd
import HordeAd.Core.AstEnv
import HordeAd.Core.AstFreshId (resetVarCounter)
import HordeAd.Core.AstInterpret
import HordeAd.Core.AstTools
import HordeAd.Core.CarriersAst
import HordeAd.Core.Delta
import HordeAd.Core.Ops
import HordeAd.Core.OpsAst

import CrossTesting
import EqEpsilon

testTrees :: [TestTree]
testTrees =
  [ testCase "KonstG0Rev" testKonstG0Rev
  , testCase "KonstG0Tiny1" testKonstG0Tiny1
  , testCase "KonstG0TinyS" testKonstG0TinyS
  , testCase "KonstG0TinyA" testKonstG0TinyA
  , testCase "KonstG0LittleA" testKonstG0LittleA
  , testCase "Replicate0Rev" testReplicate0Rev
  , testCase "Replicate0Tiny1" testReplicate0Tiny1
  , testCase "Replicate0TinyS" testReplicate0TinyS
  , testCase "Replicate0TinyA" testReplicate0TinyA
  , testCase "Replicate0LittleA" testReplicate0LittleA
  , testCase "Konst5LittleB" testKonst5LittleB
  , testCase "Konst5LittleC" testKonst5LittleC
  , testCase "Konst5BigB" testKonst5BigB
  , testCase "KonstNotBigB" testKonstNotBigB
  , testCase "Konst5BigC" testKonst5BigC
  , testCase "KonstNotBigC" testKonstNotBigC
  , testCase "Konst5LittleB128b" testKonst5LittleB128b
  , testCase "Konst5LittleC128b" testKonst5LittleC128b
  , testCase "Konst5BigB128b" testKonst5BigB128b
  , testCase "KonstNotBigB128b" testKonstNotBigB128b
  , testCase "Konst5BigC128b" testKonst5BigC128b
  , testCase "KonstNotBigC128b" testKonstNotBigC128b
  , testCase "Konst5LittleB128c" testKonst5LittleB128c
  , testCase "Konst5LittleC128c" testKonst5LittleC128c
  , testCase "Konst5BigB128c" testKonst5BigB128c
  , testCase "KonstNotBigB128c" testKonstNotBigB128c
  , testCase "Konst5BigC128c" testKonst5BigC128c
  , testCase "KonstNotBigC128c" testKonstNotBigC128c
--  , testCase "Konst5LittleB128bc" testKonst5LittleB128bc
--  , testCase "Konst5LittleC128bc" testKonst5LittleC128bc
--  , testCase "Konst5BigB128bc" testKonst5BigB128bc
--  , testCase "KonstNotBigB128cb" testKonstNotBigB128cb
--  , testCase "Konst5BigC128cb" testKonst5BigC128cb
--  , testCase "KonstNotBigC128cb" testKonstNotBigC128cb
  , testCase "Replicate0RevLaborious" testReplicate0RevLaborious
  , testCase "Replicate0Tiny1Laborious" testReplicate0Tiny1Laborious
  , testCase "Replicate0TinySLaborious" testReplicate0TinySLaborious
  , testCase "Replicate0TinyALaborious" testReplicate0TinyALaborious
  , testCase "Replicate0LittleALaborious" testReplicate0LittleALaborious
  , testCase "Konst5LittleBLaborious" testKonst5LittleBLaborious
  , testCase "Konst5LittleCLaborious" testKonst5LittleCLaborious
  , testCase "Konst5BigBLaborious" testKonst5BigBLaborious
  , testCase "KonstNotBigBLaborious" testKonstNotBigBLaborious
  , testCase "Konst5BigCLaborious" testKonst5BigCLaborious
  , testCase "KonstNotBigCLaborious" testKonstNotBigCLaborious
  , testCase "Konst5LittleBLaborious128b" testKonst5LittleBLaborious128b
  , testCase "Konst5LittleCLaborious128b" testKonst5LittleCLaborious128b
--  , testCase "Konst5BigBLaborious128b" testKonst5BigBLaborious128b
--  , testCase "KonstNotBigBLaborious128b" testKonstNotBigBLaborious128b
--  , testCase "Konst5BigCLaborious128b" testKonst5BigCLaborious128b
--  , testCase "KonstNotBigCLaborious128b" testKonstNotBigCLaborious128b
  , testCase "Konst5LittleBLaborious128c" testKonst5LittleBLaborious128c
  , testCase "Konst5LittleCLaborious128c" testKonst5LittleCLaborious128c
--  , testCase "Konst5BigBLaborious128c" testKonst5BigBLaborious128c
--  , testCase "KonstNotBigBLaborious128c" testKonstNotBigBLaborious128c
--  , testCase "Konst5BigCLaborious128c" testKonst5BigCLaborious128c
--  , testCase "KonstNotBigCLaborious128c" testKonstNotBigCLaborious128c
--  , testCase "Konst5LittleBLaborious128bc" testKonst5LittleBLaborious128bc
--  , testCase "Konst5LittleCLaborious128bc" testKonst5LittleCLaborious128bc
--  , testCase "Konst5BigBLaborious128bc" testKonst5BigBLaborious128bc
--  , testCase "KonstNotBigBLaborious128cb" testKonstNotBigBLaborious128cb
--  , testCase "Konst5BigCLaborious128cb" testKonst5BigCLaborious128cb
--  , testCase "KonstNotBigCLaborious128cb" testKonstNotBigCLaborious128cb
--  , testCase "Replicate0RevPadded" testReplicate0RevPadded
  , testCase "Replicate0Tiny1Padded" testReplicate0Tiny1Padded
  , testCase "Replicate0TinySPadded" testReplicate0TinySPadded
  , testCase "Replicate0TinyAPadded" testReplicate0TinyAPadded
  , testCase "Replicate0LittleAPadded" testReplicate0LittleAPadded
--  , testCase "Konst5LittleBPadded" testKonst5LittleBPadded
--  , testCase "Konst5LittleCPadded" testKonst5LittleCPadded
--  , testCase "Konst5BigBPadded" testKonst5BigBPadded
--  , testCase "KonstNotBigBPadded" testKonstNotBigBPadded
--  , testCase "Konst5BigCPadded" testKonst5BigCPadded
--  , testCase "KonstNotBigCPadded" testKonstNotBigCPadded
--  , testCase "Konst5LittleBPadded128b" testKonst5LittleBPadded128b
--  , testCase "Konst5LittleCPadded128b" testKonst5LittleCPadded128b
--  , testCase "Konst5BigBPadded128b" testKonst5BigBPadded128b
--  , testCase "KonstNotBigBPadded128b" testKonstNotBigBPadded128b
--  , testCase "Konst5BigCPadded128b" testKonst5BigCPadded128b
--  , testCase "KonstNotBigCPadded128b" testKonstNotBigCPadded128b
--  , testCase "Konst5LittleBPadded128c" testKonst5LittleBPadded128c
--  , testCase "Konst5LittleCPadded128c" testKonst5LittleCPadded128c
--  , testCase "Konst5BigBPadded128c" testKonst5BigBPadded128c
--  , testCase "KonstNotBigBPadded128c" testKonstNotBigBPadded128c
--  , testCase "Konst5BigCPadded128c" testKonst5BigCPadded128c
--  , testCase "KonstNotBigCPadded128c" testKonstNotBigCPadded128c
--  , testCase "Konst5LittleBPadded128bc" testKonst5LittleBPadded128bc
--  , testCase "Konst5LittleCPadded128bc" testKonst5LittleCPadded128bc
--  , testCase "Konst5BigBPadded128bc" testKonst5BigBPadded128bc
--  , testCase "KonstNotBigBPadded128cb" testKonstNotBigBPadded128cb
--  , testCase "Konst5BigCPadded128cb" testKonst5BigCPadded128cb
--  , testCase "KonstNotBigCPadded128cb" testKonstNotBigCPadded128cb
  , testCase "disparityKonst" test_disparityKonst
  , testCase "disparityKonst2" test_disparityKonst2
  , testCase "disparitySmall" test_disparitySmall
  , testCase "ConvTomsSliceRev" testTomsSliceRev
  , testCase "ConvTomsSlice" testTomsSlice
  , testCase "ConvTomsSlicePP" testTomsSlicePP
  , testCase "minimizedCNNOPP0c" testCNNOPP0c
  , testCase "minimizedCNNOPP0b" testCNNOPP0b
  , testCase "minimizedCNNOPP1e" testCNNOPP1e
  , testCase "minimizedCNNOPP2" testCNNOPP2
  , testCase "minimizedCNNOPP2b" testCNNOPP2b
--  , testCase "minimizedCNNOPP3" testCNNOPP3
  , testCase "minimizedCNNOPP3b" testCNNOPP3b
  , testCase "minimizedCNNOPP4" testCNNOPP4
  , testCase "minimizedCNNOPP4b" testCNNOPP4b
  , testCase "minimizedCNNOPP5" testCNNOPP5
  , testCase "minimizedCNNOPP5b" testCNNOPP5b
  , testCase "minimizedCNNOPP6" testCNNOPP6
  , testCase "minimizedCNNOPP6b" testCNNOPP6b
  , testCase "minimizedCNNOPP7" testCNNOPP7
  , testCase "minimizedCNNOPP7b" testCNNOPP7b
--  , testCase "minimizedPaddedCNNOPP0c" testPaddedCNNOPP0c
--  , testCase "minimizedPaddedCNNOPP0b" testPaddedCNNOPP0b
--  , testCase "minimizedPaddedCNNOPP1e" testPaddedCNNOPP1e
  , testCase "minimizedPaddedCNNOPP1b" testPaddedCNNOPP1b
  , testCase "minimizedPaddedCNNOPPLet" testPaddedCNNOPPLet
  , testCase "minimizedPaddedCNNOPPLet2" testPaddedCNNOPPLet2
--  , testCase "minimizedPaddedCNNOPP2" testPaddedCNNOPP2
  , testCase "minimizedCNNOPP0cW" testCNNOPP0cW
  , testCase "minimizedCNNOPP0bW" testCNNOPP0bW
  , testCase "minimizedCNNOPP1bW" testCNNOPP1bW
  , testCase "minimizedCNNOPP4bW" testCNNOPP4bW
  , testCase "minimizedCNNOPP4bD" testCNNOPP4bD
  , testCase "minimizedCNNOPP5aW" testCNNOPP5aW
  , testCase "minimizedCNNOPP5bW" testCNNOPP5bW
  , testCase "minimizedCNNOPP5cW" testCNNOPP5cW
  , testCase "minimizedCNNOPP5dW" testCNNOPP5dW
  ]

-- The examples reproduced and transformed in this file are borrowed
-- from https://github.com/benl23x5/adops.
-- Here they are defined using ranked tensors.

-- * A non-laborious version (depends on indexing OOB giving 0 consistently)

conv2d1
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2d1 = conv2dUnpadded $ rconcrete $ Nested.rfromListPrimLinear (fromList [1, 1, 1, 1]) [-0.2]

conv2dA
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dA = conv2dUnpadded $ rconcrete $ Nested.rfromListPrimLinear (fromList [1, 2, 1, 1]) [-0.2, 25.0003]

conv2dB
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dB = conv2dUnpadded (rconcrete $ unConcrete t16b)

testKonstG0Rev :: Assertion
testKonstG0Rev =
  assertEqualUpToEpsilon 1e-4
    (rconcrete $ Nested.rfromListPrimLinear [2, 2, 2, 2] [18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dB) (rrepl [2, 2, 2, 2] 0))

testKonstG0Tiny1 :: Assertion
testKonstG0Tiny1 =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [1, 1, 1, 1] [-0.2])
    (rev' @Double @4 conv2d1 (rrepl [1, 1, 1, 1] 0))

testKonstG0TinyS :: Assertion
testKonstG0TinyS =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [1, 1, 1, 1] [582665.99432])
    (rev' @Double @4
          (conv2dUnpadded $ rreplicate0N [1, 1, 1, 1] (rsum0 (rconcrete $ unConcrete t16b)))
          (ringestData [1, 1, 1, 1] [0]))

testKonstG0TinyA :: Assertion
testKonstG0TinyA =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [1, 2, 1, 1] [-0.2,25.0003])
    (rev' @Double @4 conv2dA (rrepl [1, 2, 1, 1] 0))

testKonstG0LittleA :: Assertion
testKonstG0LittleA =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [2, 2, 2, 2] [-0.2,-0.2,-0.2,-0.2,25.0003,25.0003,25.0003,25.0003,-0.2,-0.2,-0.2,-0.2,25.0003,25.0003,25.0003,25.0003])
    (rev' @Double @4 conv2dA (rrepl [2, 2, 2, 2] 0))

conv2dC
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dC = flip conv2dUnpadded (rconcrete $ unConcrete t16b)

conv2dB128b
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dB128b = conv2dUnpadded (rconcrete $ unConcrete t128b)

conv2dC128b
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dC128b = flip conv2dUnpadded (rconcrete $ unConcrete t128b)

conv2dB128c
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dB128c = conv2dUnpadded (rconcrete $ unConcrete t128c)

conv2dC128c
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dC128c = flip conv2dUnpadded (rconcrete $ unConcrete t128c)

testReplicate0Rev :: Assertion
testReplicate0Rev =
  assertEqualUpToEpsilon 1e-4
    (rconcrete $ Nested.rfromListPrimLinear [2, 2, 2, 2] [18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dB) (rrepl [2, 2, 2, 2] 0))

testReplicate0Tiny1 :: Assertion
testReplicate0Tiny1 =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [1, 1, 1, 1] [-0.2])
    (rev' @Double @4 conv2d1 (rrepl [1, 1, 1, 1] 0))

testReplicate0TinyS :: Assertion
testReplicate0TinyS =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [1, 1, 1, 1] [582665.99432])
    (rev' @Double @4
          (conv2dUnpadded $ rreplicate0N [1, 1, 1, 1] (rsum0 (rconcrete $ unConcrete t16b)))
          (ringestData [1, 1, 1, 1] [0]))

testReplicate0TinyA :: Assertion
testReplicate0TinyA =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [1, 2, 1, 1] [-0.2,25.0003])
    (rev' @Double @4 conv2dA (rrepl [1, 2, 1, 1] 0))

testReplicate0LittleA :: Assertion
testReplicate0LittleA =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [2, 2, 2, 2] [-0.2,-0.2,-0.2,-0.2,25.0003,25.0003,25.0003,25.0003,-0.2,-0.2,-0.2,-0.2,25.0003,25.0003,25.0003,25.0003])
    (rev' @Double @4 conv2dA (rrepl [2, 2, 2, 2] 0))

-- with data t16

testKonst5LittleB :: Assertion
testKonst5LittleB =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2, 2, 2, 2] [18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001])
    (rev' @Double @4 conv2dB (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

testKonst5LittleC :: Assertion
testKonst5LittleC =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2, 2, 2, 2] [40.1,8.0,11.0,-3.0,582625.89432,28.79432,-309.09999999999997,25.8,40.1,8.0,11.0,-3.0,582625.89432,28.79432,-309.09999999999997,25.8])
    (rev' @Double @4 conv2dC (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

testKonst5BigB :: Assertion
testKonst5BigB =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001,18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001,18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001])
    (rev' @Double @4 conv2dB (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigB :: Assertion
testKonstNotBigB =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001,18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001,18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001])
    (rev' @Double @4 conv2dB
          (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])))

testKonst5BigC :: Assertion
testKonst5BigC =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [40.1,8.0,11.0,-3.0,0.0,0.0,0.0,0.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,0.0,0.0,0.0,0.0,40.1,8.0,11.0,-3.0,0.0,0.0,0.0,0.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,0.0,0.0,0.0,0.0,40.1,8.0,11.0,-3.0,0.0,0.0,0.0,0.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,0.0,0.0,0.0,0.0])
    (rev' @Double @4 conv2dC (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigC :: Assertion
testKonstNotBigC =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [40.1,8.0,11.0,-3.0,0.0,0.0,0.0,0.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,0.0,0.0,0.0,0.0,40.1,8.0,11.0,-3.0,0.0,0.0,0.0,0.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,0.0,0.0,0.0,0.0,40.1,8.0,11.0,-3.0,0.0,0.0,0.0,0.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,0.0,0.0,0.0,0.0])
    (rev' @Double @4 conv2dC
          (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])))

-- with data t128b

testKonst5LittleB128b :: Assertion
testKonst5LittleB128b =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2, 2, 2, 2] [112.3003,251.5006,209.49462,482.69492000000014,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,112.3003,251.5006,209.49462,482.69492000000014,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004])
    (rev' @Double @4 conv2dB128b (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

testKonst5LittleC128b :: Assertion
testKonst5LittleC128b =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2, 2, 2, 2] [1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987])
    (rev' @Double @4 conv2dC128b (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

testKonst5BigB128b :: Assertion
testKonst5BigB128b =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993,112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993,112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993])
    (rev' @Double @4 conv2dB128b (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigB128b :: Assertion
testKonstNotBigB128b =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993,112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993,112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993])
    (rev' @Double @4 conv2dB128b
          (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])))

testKonst5BigC128b :: Assertion
testKonst5BigC128b =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001,1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001,1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001])
    (rev' @Double @4 conv2dC128b (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigC128b :: Assertion
testKonstNotBigC128b =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001,1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001,1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001])
    (rev' @Double @4 conv2dC128b
          (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])))

-- with data t128c

testKonst5LittleB128c :: Assertion
testKonst5LittleB128c =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2, 2, 2, 2] [54.100300000000004,111.20060000000001,119.09462,270.29492000000005,58.2,140.3,90.4,212.4,54.100300000000004,111.20060000000001,119.09462,270.29492000000005,58.2,140.3,90.4,212.4])
    (rev' @Double @4 conv2dB128c (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

testKonst5LittleC128c :: Assertion
testKonst5LittleC128c =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2, 2, 2, 2] [2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992])
    (rev' @Double @4 conv2dC128c (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

testKonst5BigB128c :: Assertion
testKonst5BigB128c =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005,54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005,54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005])
    (rev' @Double @4 conv2dB128c (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigB128c :: Assertion
testKonstNotBigB128c =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005,54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005,54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005])
    (rev' @Double @4 conv2dB128c
          (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])))

testKonst5BigC128c :: Assertion
testKonst5BigC128c =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002])
    (rev' @Double @4 conv2dC128c (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigC128c :: Assertion
testKonstNotBigC128c =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002])
    (rev' @Double @4 conv2dC128c
          (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])))

-- with data t128b and t128c
{-
testKonst5LittleB128bc :: Assertion
testKonst5LittleB128bc =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2,2,8,4] [112.3003,251.5006,417.79492000000005,494.89491000000015,209.49462,482.69492000000014,778.9778800001002,952.0721900001002,229.49462000000003,610.5892400000002,1113.1722000001,1412.1551500001997,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,3.000000000000032,65.90000000000003,106.90000000000003,173.90000000000006,164.10000000000002,365.89432000010004,593.1946200001,821.2892400002004,667.2003000000001,1060.8778800002,1500.2781800001994,1870.0614400004986,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,112.3003,251.5006,417.79492000000005,494.89491000000015,209.49462,482.69492000000014,778.9778800001002,952.0721900001002,229.49462000000003,610.5892400000002,1113.1722000001,1412.1551500001997,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,3.000000000000032,65.90000000000003,106.90000000000003,173.90000000000006,164.10000000000002,365.89432000010004,593.1946200001,821.2892400002004,667.2003000000001,1060.8778800002,1500.2781800001994,1870.0614400004986,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898])
    (rev' @Double @4 conv2dB128b t128c)

testKonst5LittleC128bc :: Assertion
testKonst5LittleC128bc =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2,2,8,4] [1627.8210700004993,1571.2321300004994,1047.1431900004002,393.6715900002,1132.9261600005002,1188.6375200005,803.7488800004002,316.57160000019996,675.7488800003999,828.6545600004001,577.7659200003001,220.57728000019998,215.6659200003,388.5716000003,245.5772800002,94.68864000010001,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2725.0393200008984,1831.7390200008983,1259.3728000004999,568.6722000005001,2551.139320000898,1660.8390200008987,1151.3728000005,501.6722000005,1903.750080000699,1174.5497800006997,803.9778800004001,340.5775800004001,854.9778800004001,628.8778800004001,450.1892400002,198.8889400002,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1627.8210700004993,1571.2321300004994,1047.1431900004002,393.6715900002,1132.9261600005002,1188.6375200005,803.7488800004002,316.57160000019996,675.7488800003999,828.6545600004001,577.7659200003001,220.57728000019998,215.6659200003,388.5716000003,245.5772800002,94.68864000010001,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2725.0393200008984,1831.7390200008983,1259.3728000004999,568.6722000005001,2551.139320000898,1660.8390200008987,1151.3728000005,501.6722000005,1903.750080000699,1174.5497800006997,803.9778800004001,340.5775800004001,854.9778800004001,628.8778800004001,450.1892400002,198.8889400002,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0])
    (rev' @Double @4 conv2dC128b t128c)

testKonst5BigB128bc :: Assertion
testKonst5BigB128bc =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2,2,8,4] [112.3003,251.5006,417.79492000000005,494.89491000000015,209.49462,482.69492000000014,778.9778800001002,952.0721900001002,229.49462000000003,610.5892400000002,1113.1722000001,1412.1551500001997,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,3.000000000000032,65.90000000000003,106.90000000000003,173.90000000000006,164.10000000000002,365.89432000010004,593.1946200001,821.2892400002004,667.2003000000001,1060.8778800002,1500.2781800001994,1870.0614400004986,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,112.3003,251.5006,417.79492000000005,494.89491000000015,209.49462,482.69492000000014,778.9778800001002,952.0721900001002,229.49462000000003,610.5892400000002,1113.1722000001,1412.1551500001997,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,3.000000000000032,65.90000000000003,106.90000000000003,173.90000000000006,164.10000000000002,365.89432000010004,593.1946200001,821.2892400002004,667.2003000000001,1060.8778800002,1500.2781800001994,1870.0614400004986,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898])
    (rev' @Double @4 conv2dB128b t128c)

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigB128cb :: Assertion
testKonstNotBigB128cb =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [4,2,4,4] [54.100300000000004,111.20060000000001,191.4006,228.4006,119.09462,270.29492000000005,435.28356000009995,519.1778800001,109.09462000000002,318.19492,563.3835600001,687.2778800001003,174.08894000000004,477.28924000000006,774.2665200002001,931.9551600002003,58.2,140.3,226.39432,266.49431000000004,90.4,212.4,343.69432000000006,432.89431000000013,120.4,292.39432000000005,549.78864,724.8772700001001,-117.5,103.38864000010005,459.88296000009996,695.8659100003002,54.100300000000004,111.20060000000001,191.4006,228.4006,119.09462,270.29492000000005,435.28356000009995,519.1778800001,109.09462000000002,318.19492,563.3835600001,687.2778800001003,174.08894000000004,477.28924000000006,774.2665200002001,931.9551600002003,58.2,140.3,226.39432,266.49431000000004,90.4,212.4,343.69432000000006,432.89431000000013,120.4,292.39432000000005,549.78864,724.8772700001001,-117.5,103.38864000010005,459.88296000009996,695.8659100003002,54.100300000000004,111.20060000000001,191.4006,228.4006,119.09462,270.29492000000005,435.28356000009995,519.1778800001,109.09462000000002,318.19492,563.3835600001,687.2778800001003,174.08894000000004,477.28924000000006,774.2665200002001,931.9551600002003,58.2,140.3,226.39432,266.49431000000004,90.4,212.4,343.69432000000006,432.89431000000013,120.4,292.39432000000005,549.78864,724.8772700001001,-117.5,103.38864000010005,459.88296000009996,695.8659100003002,54.100300000000004,111.20060000000001,191.4006,228.4006,119.09462,270.29492000000005,435.28356000009995,519.1778800001,109.09462000000002,318.19492,563.3835600001,687.2778800001003,174.08894000000004,477.28924000000006,774.2665200002001,931.9551600002003,58.2,140.3,226.39432,266.49431000000004,90.4,212.4,343.69432000000006,432.89431000000013,120.4,292.39432000000005,549.78864,724.8772700001001,-117.5,103.38864000010005,459.88296000009996,695.8659100003002])
    (rev' @Double @4 conv2dB128c t128b)

testKonst5BigC128cb :: Assertion
testKonst5BigC128cb =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [4,2,4,4] [2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003])
    (rev' @Double @4 conv2dC128c t128b)

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigC128cb :: Assertion
testKonstNotBigC128cb =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [4,2,4,4] [2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003])
    (rev' @Double @4 conv2dC128c t128b)
-}


-- * A laborious version (meaning, out of bounds indexing is handled explicitly)

-- | Unpadded full convolution,
--   where the output size is the same as the input size.
--
-- It guards the out of bounds indexing behind a conditional
-- to prevent changed values after vectorization,
-- but the guarding is no longer needed, so this is only for testing.
--
-- BTW, the indexing lower bounds in the code are spurious,
-- so they get simplified away in the resulting AST program.
conv2dUnpaddedL
  :: (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r) -> target (TKR 4 r)
conv2dUnpaddedL arrK arrA =
  let [nImgs, nCinpA, nAh, nAw] = rshape arrA
      [nCoutK, nCinpK, nKh, nKw] = rshape arrK
      nCinp = assert (nCinpA == nCinpK `blame` (nCinpA, nCinpK)) nCinpA
      shB = [nImgs, nCoutK, nAh, nAw]
      shK1 = [1, nCinp, nKh, nKw]
  in rbuild shB $ \case
    [iImg, iCout, iBh, iBw] ->
      let arrAt = slicezL shK1 arrA [iImg, 0, iBh, iBw]
          arrKt = slicezL shK1 arrK [iCout, 0, 0, 0]
      in rdot0 arrAt arrKt
    _ -> error "conv2dUnpaddedL: impossible pattern needlessly required"

-- | Slice a section out of a tensor,
--   given a base offset and shape of the section.
--
--   If the slice extends out side the source array then the corresponding
--   elements are set to zero.
slicezL
  :: (ADReady target, GoodScalar r, KnownNat n)
  => IShR n -> target (TKR n r) -> IxROf target n -> target (TKR n r)
slicezL shOut d ixBase =
  rbuild shOut $ \ixResult -> indexz0L d (ixrZipWith (+) ixBase ixResult)

-- | Retrieve the element at the given index,
--   returning zero for out of range indices.
--
-- Warning: this uses ix twice and within0 again uses it twice,
-- so this variant without tlet should be used only when it's known
-- that ix is of small constant size (e.g., if it contains conditionals
-- that compare big tensors or their minimal elements, it likely is not,
-- unless the tensors are under tlet and only variables representing them
-- are used).
indexz0L
  :: forall target r n. (ADReady target, GoodScalar r, KnownNat n)
  => target (TKR n r) -> IxROf target n -> target (TKR 0 r)
indexz0L d ix = ifH (within0 @target (rshape @target d) ix) (d ! ix) (rscalar 0)

-- | Given an index and shape, check if the index is fully within the shape.
-- Note that @ix@ is used twice, so should be shared outside.
within0
  :: forall target n. (ADReady target, KnownNat n)
  => IShR n -> IxROf target n -> BoolOf target
within0 sh ix =
  let within :: IntOf target -> IntOf target -> BoolOf target
      within i dim = 0 <=. i &&* dim >. i
  in foldr (&&*) true
     $ zipWith within (toList ix) (map fromIntegral $ toList sh)

conv2d1Laborious
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2d1Laborious = conv2dUnpaddedL $ rconcrete $ Nested.rfromListPrimLinear (fromList [1, 1, 1, 1]) [-0.2]

conv2dALaborious
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dALaborious =
  conv2dUnpaddedL $ rconcrete $ Nested.rfromListPrimLinear (fromList [1, 2, 1, 1]) [-0.2, 25.0003]

conv2dBLaborious
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dBLaborious = conv2dUnpaddedL (rconcrete $ unConcrete t16b)

conv2dCLaborious
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dCLaborious = flip conv2dUnpaddedL (rconcrete $ unConcrete t16b)

conv2dBLaborious128b
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dBLaborious128b = conv2dUnpaddedL (rconcrete $ unConcrete t128b)

conv2dCLaborious128b
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dCLaborious128b = flip conv2dUnpaddedL (rconcrete $ unConcrete t128b)

conv2dBLaborious128c
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dBLaborious128c = conv2dUnpaddedL (rconcrete $ unConcrete t128c)

conv2dCLaborious128c
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dCLaborious128c = flip conv2dUnpaddedL (rconcrete $ unConcrete t128c)

testReplicate0RevLaborious :: Assertion
testReplicate0RevLaborious =
  assertEqualUpToEpsilon 1e-4
    (rconcrete $ Nested.rfromListPrimLinear [2, 2, 2, 2] [18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dBLaborious) (rrepl [2, 2, 2, 2] 0))

testReplicate0Tiny1Laborious :: Assertion
testReplicate0Tiny1Laborious =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [1, 1, 1, 1] [-0.2])
    (rev' @Double @4 conv2d1Laborious (rrepl [1, 1, 1, 1] 0))

testReplicate0TinySLaborious :: Assertion
testReplicate0TinySLaborious =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [1, 1, 1, 1] [582665.99432])
    (rev' @Double @4
          (conv2dUnpaddedL $ rreplicate0N [1, 1, 1, 1] (rsum0 (rconcrete $ unConcrete t16b)))
          (ringestData [1, 1, 1, 1] [0]))

testReplicate0TinyALaborious :: Assertion
testReplicate0TinyALaborious =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [1, 2, 1, 1] [-0.2,25.0003])
    (rev' @Double @4 conv2dALaborious (rrepl [1, 2, 1, 1] 0))

testReplicate0LittleALaborious :: Assertion
testReplicate0LittleALaborious =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [2, 2, 2, 2] [-0.2,-0.2,-0.2,-0.2,25.0003,25.0003,25.0003,25.0003,-0.2,-0.2,-0.2,-0.2,25.0003,25.0003,25.0003,25.0003])
    (rev' @Double @4 conv2dALaborious (rrepl [2, 2, 2, 2] 0))

-- with data t16

testKonst5LittleBLaborious :: Assertion
testKonst5LittleBLaborious =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2, 2, 2, 2] [18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001])
    (rev' @Double @4 conv2dBLaborious (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

testKonst5LittleCLaborious :: Assertion
testKonst5LittleCLaborious =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2, 2, 2, 2] [40.1,8.0,11.0,-3.0,582625.89432,28.79432,-309.09999999999997,25.8,40.1,8.0,11.0,-3.0,582625.89432,28.79432,-309.09999999999997,25.8])
    (rev' @Double @4 conv2dCLaborious (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

testKonst5BigBLaborious :: Assertion
testKonst5BigBLaborious =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001,18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001,18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001])
    (rev' @Double @4 conv2dBLaborious (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigBLaborious :: Assertion
testKonstNotBigBLaborious =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001,18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001,18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001])
    (rev' @Double @4 conv2dBLaborious
          (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])))

testKonst5BigCLaborious :: Assertion
testKonst5BigCLaborious =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [40.1,8.0,11.0,-3.0,0.0,0.0,0.0,0.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,0.0,0.0,0.0,0.0,40.1,8.0,11.0,-3.0,0.0,0.0,0.0,0.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,0.0,0.0,0.0,0.0,40.1,8.0,11.0,-3.0,0.0,0.0,0.0,0.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,0.0,0.0,0.0,0.0])
    (rev' @Double @4 conv2dCLaborious (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigCLaborious :: Assertion
testKonstNotBigCLaborious =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [40.1,8.0,11.0,-3.0,0.0,0.0,0.0,0.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,0.0,0.0,0.0,0.0,40.1,8.0,11.0,-3.0,0.0,0.0,0.0,0.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,0.0,0.0,0.0,0.0,40.1,8.0,11.0,-3.0,0.0,0.0,0.0,0.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,0.0,0.0,0.0,0.0])
    (rev' @Double @4 conv2dCLaborious
          (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])))

-- with data t128b

testKonst5LittleBLaborious128b :: Assertion
testKonst5LittleBLaborious128b =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2, 2, 2, 2] [112.3003,251.5006,209.49462,482.69492000000014,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,112.3003,251.5006,209.49462,482.69492000000014,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004])
    (rev' @Double @4 conv2dBLaborious128b (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

testKonst5LittleCLaborious128b :: Assertion
testKonst5LittleCLaborious128b =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2, 2, 2, 2] [1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987])
    (rev' @Double @4 conv2dCLaborious128b (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

{-
testKonst5BigBLaborious128b :: Assertion
testKonst5BigBLaborious128b =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993,112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993,112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993])
    (rev' @Double @4 conv2dBLaborious128b (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigBLaborious128b :: Assertion
testKonstNotBigBLaborious128b =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993,112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993,112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993])
    (rev' @Double @4 conv2dBLaborious128b
          (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])))

testKonst5BigCLaborious128b :: Assertion
testKonst5BigCLaborious128b =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001,1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001,1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001])
    (rev' @Double @4 conv2dCLaborious128b (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigCLaborious128b :: Assertion
testKonstNotBigCLaborious128b =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001,1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001,1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001])
    (rev' @Double @4 conv2dCLaborious128b
          (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])))
-}

-- with data t128c

testKonst5LittleBLaborious128c :: Assertion
testKonst5LittleBLaborious128c =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2, 2, 2, 2] [54.100300000000004,111.20060000000001,119.09462,270.29492000000005,58.2,140.3,90.4,212.4,54.100300000000004,111.20060000000001,119.09462,270.29492000000005,58.2,140.3,90.4,212.4])
    (rev' @Double @4 conv2dBLaborious128c (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

testKonst5LittleCLaborious128c :: Assertion
testKonst5LittleCLaborious128c =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2, 2, 2, 2] [2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992])
    (rev' @Double @4 conv2dCLaborious128c (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

{-
testKonst5BigBLaborious128c :: Assertion
testKonst5BigBLaborious128c =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005,54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005,54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005])
    (rev' @Double @4 conv2dBLaborious128c (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigBLaborious128c :: Assertion
testKonstNotBigBLaborious128c =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005,54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005,54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005])
    (rev' @Double @4 conv2dBLaborious128c
          (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])))

testKonst5BigCLaborious128c :: Assertion
testKonst5BigCLaborious128c =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002])
    (rev' @Double @4 conv2dCLaborious128c (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigCLaborious128c :: Assertion
testKonstNotBigCLaborious128c =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002])
    (rev' @Double @4 conv2dCLaborious128c
          (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])))
-}

-- with data t128b and t128c
{-
testKonst5LittleBLaborious128bc :: Assertion
testKonst5LittleBLaborious128bc =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2,2,8,4] [112.3003,251.5006,417.79492000000005,494.89491000000015,209.49462,482.69492000000014,778.9778800001002,952.0721900001002,229.49462000000003,610.5892400000002,1113.1722000001,1412.1551500001997,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,3.000000000000032,65.90000000000003,106.90000000000003,173.90000000000006,164.10000000000002,365.89432000010004,593.1946200001,821.2892400002004,667.2003000000001,1060.8778800002,1500.2781800001994,1870.0614400004986,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,112.3003,251.5006,417.79492000000005,494.89491000000015,209.49462,482.69492000000014,778.9778800001002,952.0721900001002,229.49462000000003,610.5892400000002,1113.1722000001,1412.1551500001997,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,3.000000000000032,65.90000000000003,106.90000000000003,173.90000000000006,164.10000000000002,365.89432000010004,593.1946200001,821.2892400002004,667.2003000000001,1060.8778800002,1500.2781800001994,1870.0614400004986,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898])
    (rev' @Double @4 conv2dBLaborious128b t128c)

testKonst5LittleCLaborious128bc :: Assertion
testKonst5LittleCLaborious128bc =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2,2,8,4] [1627.8210700004993,1571.2321300004994,1047.1431900004002,393.6715900002,1132.9261600005002,1188.6375200005,803.7488800004002,316.57160000019996,675.7488800003999,828.6545600004001,577.7659200003001,220.57728000019998,215.6659200003,388.5716000003,245.5772800002,94.68864000010001,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2725.0393200008984,1831.7390200008983,1259.3728000004999,568.6722000005001,2551.139320000898,1660.8390200008987,1151.3728000005,501.6722000005,1903.750080000699,1174.5497800006997,803.9778800004001,340.5775800004001,854.9778800004001,628.8778800004001,450.1892400002,198.8889400002,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1627.8210700004993,1571.2321300004994,1047.1431900004002,393.6715900002,1132.9261600005002,1188.6375200005,803.7488800004002,316.57160000019996,675.7488800003999,828.6545600004001,577.7659200003001,220.57728000019998,215.6659200003,388.5716000003,245.5772800002,94.68864000010001,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2725.0393200008984,1831.7390200008983,1259.3728000004999,568.6722000005001,2551.139320000898,1660.8390200008987,1151.3728000005,501.6722000005,1903.750080000699,1174.5497800006997,803.9778800004001,340.5775800004001,854.9778800004001,628.8778800004001,450.1892400002,198.8889400002,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0])
    (rev' @Double @4 conv2dCLaborious128b t128c)

testKonst5BigBLaborious128bc :: Assertion
testKonst5BigBLaborious128bc =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2,2,8,4] [112.3003,251.5006,417.79492000000005,494.89491000000015,209.49462,482.69492000000014,778.9778800001002,952.0721900001002,229.49462000000003,610.5892400000002,1113.1722000001,1412.1551500001997,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,3.000000000000032,65.90000000000003,106.90000000000003,173.90000000000006,164.10000000000002,365.89432000010004,593.1946200001,821.2892400002004,667.2003000000001,1060.8778800002,1500.2781800001994,1870.0614400004986,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,112.3003,251.5006,417.79492000000005,494.89491000000015,209.49462,482.69492000000014,778.9778800001002,952.0721900001002,229.49462000000003,610.5892400000002,1113.1722000001,1412.1551500001997,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,3.000000000000032,65.90000000000003,106.90000000000003,173.90000000000006,164.10000000000002,365.89432000010004,593.1946200001,821.2892400002004,667.2003000000001,1060.8778800002,1500.2781800001994,1870.0614400004986,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898])
    (rev' @Double @4 conv2dBLaborious128b t128c)

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigBLaborious128cb :: Assertion
testKonstNotBigBLaborious128cb =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [4,2,4,4] [54.100300000000004,111.20060000000001,191.4006,228.4006,119.09462,270.29492000000005,435.28356000009995,519.1778800001,109.09462000000002,318.19492,563.3835600001,687.2778800001003,174.08894000000004,477.28924000000006,774.2665200002001,931.9551600002003,58.2,140.3,226.39432,266.49431000000004,90.4,212.4,343.69432000000006,432.89431000000013,120.4,292.39432000000005,549.78864,724.8772700001001,-117.5,103.38864000010005,459.88296000009996,695.8659100003002,54.100300000000004,111.20060000000001,191.4006,228.4006,119.09462,270.29492000000005,435.28356000009995,519.1778800001,109.09462000000002,318.19492,563.3835600001,687.2778800001003,174.08894000000004,477.28924000000006,774.2665200002001,931.9551600002003,58.2,140.3,226.39432,266.49431000000004,90.4,212.4,343.69432000000006,432.89431000000013,120.4,292.39432000000005,549.78864,724.8772700001001,-117.5,103.38864000010005,459.88296000009996,695.8659100003002,54.100300000000004,111.20060000000001,191.4006,228.4006,119.09462,270.29492000000005,435.28356000009995,519.1778800001,109.09462000000002,318.19492,563.3835600001,687.2778800001003,174.08894000000004,477.28924000000006,774.2665200002001,931.9551600002003,58.2,140.3,226.39432,266.49431000000004,90.4,212.4,343.69432000000006,432.89431000000013,120.4,292.39432000000005,549.78864,724.8772700001001,-117.5,103.38864000010005,459.88296000009996,695.8659100003002,54.100300000000004,111.20060000000001,191.4006,228.4006,119.09462,270.29492000000005,435.28356000009995,519.1778800001,109.09462000000002,318.19492,563.3835600001,687.2778800001003,174.08894000000004,477.28924000000006,774.2665200002001,931.9551600002003,58.2,140.3,226.39432,266.49431000000004,90.4,212.4,343.69432000000006,432.89431000000013,120.4,292.39432000000005,549.78864,724.8772700001001,-117.5,103.38864000010005,459.88296000009996,695.8659100003002])
    (rev' @Double @4 conv2dBLaborious128c t128b)

testKonst5BigCLaborious128cb :: Assertion
testKonst5BigCLaborious128cb =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [4,2,4,4] [2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003])
    (rev' @Double @4 conv2dCLaborious128c t128b)

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigCLaborious128cb :: Assertion
testKonstNotBigCLaborious128cb =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [4,2,4,4] [2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003])
    (rev' @Double @4 conv2dCLaborious128c t128b)
-}


-- * A padded version (out of bounds indexing is not possible)

-- | Full convolution with just enough extra external zero padding
--   to ensure that the output size is the same as the input size
--   and all input points are read the same number of times.
--
--   The same result could be accomplished by tweaking indexes slightly
--   in conv2dUnpadded, but here additionally all bounds checks in the code
--   are spurious and will be simplified away in the resulting AST program.
conv2dPadded
  :: forall target r. (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r) -> target (TKR 4 r)
conv2dPadded arrK arrA =
  let [nImgs, nCinpA, nAh, nAw] = rshape arrA
      [nCoutK, nCinpK, nKh, nKw] = rshape arrK
      shAPadded = [nImgs, nCinpA, nAh + nKh, nAw + nKw]
      arrAPadded = rbuild @4 @0 @(TKScalar r) @target shAPadded $ \case
        [iImg, iCinp, iPh, iPw] ->
          ifH (iPh <. fromIntegral (nKh `div` 2)
               ||* iPw <. fromIntegral (nKw `div` 2)
               ||* iPh >=. fromIntegral (nAh + nKh `div` 2)
               ||* iPw >=. fromIntegral (nAw + nKw `div` 2))
              (rscalar 0)
              (arrA ! [ iImg
                      , iCinp
                      , iPh - fromIntegral (nKh `div` 2)
                      , iPw - fromIntegral (nKw `div` 2) ])
      nCinp = assert (nCinpA == nCinpK `blame` (nCinpA, nCinpK)) nCinpA
      shB = [nImgs, nCoutK, nAh, nAw]
      shK1 = [1, nCinp, nKh, nKw]
  in rbuild shB $ \case
    [iImg, iCout, iBh, iBw] ->
      let arrAt = slicezL shK1 arrAPadded [iImg, 0, iBh, iBw]
          arrKt = slicezL shK1 arrK [iCout, 0, 0, 0]
      in rdot0 arrAt arrKt
    _ -> error "conv2dPadded: impossible pattern needlessly required"

conv2d1Padded
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2d1Padded = conv2dPadded $ rconcrete $ Nested.rfromListPrimLinear (fromList [1, 1, 1, 1]) [-0.2]

conv2dAPadded
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dAPadded =
  conv2dPadded $ rconcrete $ Nested.rfromListPrimLinear (fromList [1, 2, 1, 1]) [-0.2, 25.0003]

conv2dBPadded
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dBPadded = conv2dPadded (rconcrete $ unConcrete t16b)

conv2dCPadded
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dCPadded = flip conv2dPadded (rconcrete $ unConcrete t16b)

conv2dBPadded128b
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dBPadded128b = conv2dPadded (rconcrete $ unConcrete t128b)

conv2dCPadded128b
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dCPadded128b = flip conv2dPadded (rconcrete $ unConcrete t128b)

_conv2dBPadded128c
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
_conv2dBPadded128c = conv2dPadded (rconcrete $ unConcrete t128c)

_conv2dCPadded128c
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
_conv2dCPadded128c = flip conv2dPadded (rconcrete $ unConcrete t128c)

-- TODO: OOMs
_testReplicate0RevPadded :: Assertion
_testReplicate0RevPadded =
  assertEqualUpToEpsilon 1e-4
    (rconcrete $ Nested.rfromListPrimLinear [2, 2, 2, 2] [40.1,8.0,11.0,-3.0,582625.89432,28.79432,-309.09999999999997,25.8,40.1,8.0,11.0,-3.0,582625.89432,28.79432,-309.09999999999997,25.8])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dBPadded) (rrepl [2, 2, 2, 2] 0))

testReplicate0Tiny1Padded :: Assertion
testReplicate0Tiny1Padded =
  assertEqualUpToEpsilon 1e-10
    (ringestData [1, 1, 1, 1] [-0.2])
    (cgrad (kfromR . rsum0 @4 @(TKScalar Double) . conv2d1Padded) (rrepl [1, 1, 1, 1] 0))

testReplicate0TinySPadded :: Assertion
testReplicate0TinySPadded =
  assertEqualUpToEpsilon 1e-10
    (ringestData [1, 1, 1, 1] [582665.99432])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) .
          (conv2dPadded $ rreplicate0N [1, 1, 1, 1] (rsum0 (rconcrete $ unConcrete t16b))))
          (ringestData [1, 1, 1, 1] [0]))

testReplicate0TinyAPadded :: Assertion
testReplicate0TinyAPadded =
  assertEqualUpToEpsilon 1e-10
    (ringestData [1, 2, 1, 1] [-0.2,25.0003])
    (cgrad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dAPadded) (rrepl [1, 2, 1, 1] 0))

testReplicate0LittleAPadded :: Assertion
testReplicate0LittleAPadded =
  assertEqualUpToEpsilon 1e-10
    (ringestData [2, 2, 2, 2] [-0.2,-0.2,-0.2,-0.2,25.0003,25.0003,25.0003,25.0003,-0.2,-0.2,-0.2,-0.2,25.0003,25.0003,25.0003,25.0003])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dAPadded) (rrepl [2, 2, 2, 2] 0))

-- with data t16

-- TODO: OOMs
_testKonst5LittleBPadded :: Assertion
_testKonst5LittleBPadded =
  assertEqualUpToEpsilon 1e-8
    (ringestData [2, 2, 2, 2] [40.1,8.0,11.0,-3.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,40.1,8.0,11.0,-3.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dBPadded) (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

-- TODO: OOMs
_testKonst5LittleCPadded :: Assertion
_testKonst5LittleCPadded =
  assertEqualUpToEpsilon 1e-8
    (ringestData [2, 2, 2, 2] [18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dCPadded) (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

-- TODO: OOMs
_testKonst5BigBPadded :: Assertion
_testKonst5BigBPadded =
  assertEqualUpToEpsilon 1e-8
    (ringestData [3, 2, 4, 2] [40.1,8.0,40.1,8.0,40.1,8.0,11.0,-3.0,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,40.1,8.0,40.1,8.0,40.1,8.0,11.0,-3.0,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,40.1,8.0,40.1,8.0,40.1,8.0,11.0,-3.0,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,-309.09999999999997,25.8])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dBPadded) (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- TODO: OOMs
-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
_testKonstNotBigBPadded :: Assertion
_testKonstNotBigBPadded =
  assertEqualUpToEpsilon 1e-8
    (ringestData [3, 2, 4, 2] [40.1,8.0,40.1,8.0,40.1,8.0,11.0,-3.0,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,40.1,8.0,40.1,8.0,40.1,8.0,11.0,-3.0,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,40.1,8.0,40.1,8.0,40.1,8.0,11.0,-3.0,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,-309.09999999999997,25.8])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dBPadded)
          (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])))

-- TODO: OOMs
_testKonst5BigCPadded :: Assertion
_testKonst5BigCPadded =
  assertEqualUpToEpsilon 1e-8
    (ringestData [3, 2, 4, 2] [0.0,0.0,18.1,29.1,32.1,40.1,14.0,11.0,0.0,0.0,582932.0,582934.99432,582597.1,582625.8943200001,-334.9,-309.09999999999997,0.0,0.0,18.1,29.1,32.1,40.1,14.0,11.0,0.0,0.0,582932.0,582934.99432,582597.1,582625.8943200001,-334.9,-309.09999999999997,0.0,0.0,18.1,29.1,32.1,40.1,14.0,11.0,0.0,0.0,582932.0,582934.99432,582597.1,582625.8943200001,-334.9,-309.09999999999997])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dCPadded) (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- TODO: OOMs
-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
_testKonstNotBigCPadded :: Assertion
_testKonstNotBigCPadded =
  assertEqualUpToEpsilon 1e-8
    (ringestData [3, 2, 4, 2] [0.0,0.0,18.1,29.1,32.1,40.1,14.0,11.0,0.0,0.0,582932.0,582934.99432,582597.1,582625.8943200001,-334.9,-309.09999999999997,0.0,0.0,18.1,29.1,32.1,40.1,14.0,11.0,0.0,0.0,582932.0,582934.99432,582597.1,582625.8943200001,-334.9,-309.09999999999997,0.0,0.0,18.1,29.1,32.1,40.1,14.0,11.0,0.0,0.0,582932.0,582934.99432,582597.1,582625.8943200001,-334.9,-309.09999999999997])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dCPadded)
          (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])))

-- with data t128b

-- TODO: OOMs
_testKonst5LittleBPadded128b :: Assertion
_testKonst5LittleBPadded128b =
  assertEqualUpToEpsilon 1e-8
    (ringestData [2, 2, 2, 2] [578.1829600001,558.1716000002,608.0772800002001,577.7659200003001,729.1778800002002,701.1835600003001,833.9722000003002,803.9778800004001,578.1829600001,558.1716000002,608.0772800002001,577.7659200003001,729.1778800002002,701.1835600003001,833.9722000003002,803.9778800004001])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dBPadded128b) (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

-- TODO: OOMs
_testKonst5LittleCPadded128b :: Assertion
_testKonst5LittleCPadded128b =
  assertEqualUpToEpsilon 1e-8
    (ringestData [2, 2, 2, 2] [1113.1722000001,1412.1551500001997,1234.1494800003002,1627.8210700004993,1500.2781800001994,1870.0614400004986,2156.3671200003987,2725.0393200008984,1113.1722000001,1412.1551500001997,1234.1494800003002,1627.8210700004993,1500.2781800001994,1870.0614400004986,2156.3671200003987,2725.0393200008984])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dCPadded128b) (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

{-
testKonst5BigBPadded128b :: Assertion
testKonst5BigBPadded128b =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993,112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993,112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993])
    (rev' @Double @4 conv2dBPadded128b (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigBPadded128b :: Assertion
testKonstNotBigBPadded128b =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993,112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993,112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993])
    (rev' @Double @4 conv2dBPadded128b
          (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])))

testKonst5BigCPadded128b :: Assertion
testKonst5BigCPadded128b =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001,1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001,1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001])
    (rev' @Double @4 conv2dCPadded128b (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigCPadded128b :: Assertion
testKonstNotBigCPadded128b =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001,1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001,1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001])
    (rev' @Double @4 conv2dCPadded128b
          (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])))

-- with data t128c

testKonst5LittleBPadded128c :: Assertion
testKonst5LittleBPadded128c =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2, 2, 2, 2] [186.7886400001,121.7829600001,269.09432000009997,261.3943200001,210.9943200001,231.79432000010002,160.00030000000004,194.00060000000005,186.7886400001,121.7829600001,269.09432000009997,261.3943200001,210.9943200001,231.79432000010002,160.00030000000004,194.00060000000005])
    (rev' @Double @4 conv2dBPadded128c (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

testKonst5LittleCPadded128c :: Assertion
testKonst5LittleCPadded128c =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2, 2, 2, 2] [1772.649480000399,2138.4267600005987,2157.0438000004983,2640.8154000007976,961.7781800001002,1359.4557500003987,1233.4728000001987,1712.044990000598,1772.649480000399,2138.4267600005987,2157.0438000004983,2640.8154000007976,961.7781800001002,1359.4557500003987,1233.4728000001987,1712.044990000598])
    (rev' @Double @4 conv2dCPadded128c (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

testKonst5BigBPadded128c :: Assertion
testKonst5BigBPadded128c =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005,54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005,54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005])
    (rev' @Double @4 conv2dBPadded128c (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigBPadded128c :: Assertion
testKonstNotBigBPadded128c =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005,54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005,54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005])
    (rev' @Double @4 conv2dBPadded128c
          (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])))

testKonst5BigCPadded128c :: Assertion
testKonst5BigCPadded128c =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002])
    (rev' @Double @4 conv2dCPadded128c (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigCPadded128c :: Assertion
testKonstNotBigCPadded128c =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002])
    (rev' @Double @4 conv2dCPadded128c
          (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])))
-}

-- with data t128b and t128c
{-
testKonst5LittleBPadded128bc :: Assertion
testKonst5LittleBPadded128bc =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2,2,8,4] [1113.1722000001,1412.1551500001997,1182.6605300001997,801.5659100002002,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,816.3545600003002,1132.9261600005,1188.6375200004998,803.7488800004002,455.17160000019993,675.7488800004002,828.6545600004,577.7659200003001,1500.2781800001994,1870.0614400004986,1202.8611400004997,809.1835600003001,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2049.4671200003986,2551.139320000898,1660.8390200008987,1151.3728000004999,1563.172500000299,1903.7500800006983,1174.5497800006997,803.9778800004001,1113.1722000001,1412.1551500001997,1182.6605300001997,801.5659100002002,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,816.3545600003002,1132.9261600005,1188.6375200004998,803.7488800004002,455.17160000019993,675.7488800004002,828.6545600004,577.7659200003001,1500.2781800001994,1870.0614400004986,1202.8611400004997,809.1835600003001,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2049.4671200003986,2551.139320000898,1660.8390200008987,1151.3728000004999,1563.172500000299,1903.7500800006983,1174.5497800006997,803.9778800004001])
    (rev' @Double @4 conv2dBPadded128b t128c)

testKonst5LittleCPadded128bc :: Assertion
testKonst5LittleCPadded128bc =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2,2,8,4] [0.0,0.0,0.0,0.0,251.5006,417.7949200000001,494.8949100000001,382.59461000000005,482.69492000000014,778.9778800001002,952.0721900001002,742.5775700001002,610.5892400000002,1113.1722000001,1412.1551500001997,1182.6605300002002,580.6778800001001,1234.1494800003002,1627.8210700004993,1571.2321300004994,329.17728000010004,816.3545600003002,1132.9261600005002,1188.6375200005,97.98296000010004,455.17160000020016,675.7488800003999,828.6545600004001,-29.9113599999,120.97728000019995,215.6659200003,388.5716000003,0.0,0.0,0.0,0.0,65.90000000000003,106.90000000000003,173.90000000000006,170.90000000000003,365.89432000010004,593.1946200001,821.2892400002003,657.1892400002001,1060.8778800002,1500.2781800001994,1870.0614400004986,1202.8611400005,1465.6665200003995,2156.3671200003987,2725.0393200008984,1831.7390200008983,1399.7665200003996,2049.4671200003986,2551.139320000898,1660.8390200008987,1099.7722000003,1563.1725000002994,1903.750080000699,1174.5497800006997,404.7886400002001,656.0889400002,854.9778800004001,628.8778800004001,0.0,0.0,0.0,0.0,251.5006,417.7949200000001,494.8949100000001,382.59461000000005,482.69492000000014,778.9778800001002,952.0721900001002,742.5775700001002,610.5892400000002,1113.1722000001,1412.1551500001997,1182.6605300002002,580.6778800001001,1234.1494800003002,1627.8210700004993,1571.2321300004994,329.17728000010004,816.3545600003002,1132.9261600005002,1188.6375200005,97.98296000010004,455.17160000020016,675.7488800003999,828.6545600004001,-29.9113599999,120.97728000019995,215.6659200003,388.5716000003,0.0,0.0,0.0,0.0,65.90000000000003,106.90000000000003,173.90000000000006,170.90000000000003,365.89432000010004,593.1946200001,821.2892400002003,657.1892400002001,1060.8778800002,1500.2781800001994,1870.0614400004986,1202.8611400005,1465.6665200003995,2156.3671200003987,2725.0393200008984,1831.7390200008983,1399.7665200003996,2049.4671200003986,2551.139320000898,1660.8390200008987,1099.7722000003,1563.1725000002994,1903.750080000699,1174.5497800006997,404.7886400002001,656.0889400002,854.9778800004001,628.8778800004001])
    (rev' @Double @4 conv2dCPadded128b t128c)

testKonst5BigBPadded128bc :: Assertion
testKonst5BigBPadded128bc =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2,2,8,4] [1113.1722000001,1412.1551500001997,1182.6605300001997,801.5659100002002,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,816.3545600003002,1132.9261600005,1188.6375200004998,803.7488800004002,455.17160000019993,675.7488800004002,828.6545600004,577.7659200003001,1500.2781800001994,1870.0614400004986,1202.8611400004997,809.1835600003001,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2049.4671200003986,2551.139320000898,1660.8390200008987,1151.3728000004999,1563.172500000299,1903.7500800006983,1174.5497800006997,803.9778800004001,1113.1722000001,1412.1551500001997,1182.6605300001997,801.5659100002002,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,816.3545600003002,1132.9261600005,1188.6375200004998,803.7488800004002,455.17160000019993,675.7488800004002,828.6545600004,577.7659200003001,1500.2781800001994,1870.0614400004986,1202.8611400004997,809.1835600003001,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2049.4671200003986,2551.139320000898,1660.8390200008987,1151.3728000004999,1563.172500000299,1903.7500800006983,1174.5497800006997,803.9778800004001])
    (rev' @Double @4 conv2dBPadded128b t128c)

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigBPadded128cb :: Assertion
testKonstNotBigBPadded128cb =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [4,2,4,4] [606.6659200002001,754.4545600002001,651.5659200002001,373.6659200002,720.1772800002001,917.0659200003003,749.9716000003002,467.1772800001999,1209.2659200003,1451.1488800004995,884.9545600005001,547.1716000002999,1382.7772800002997,1708.860240000599,1078.4602400006002,708.7829600003,316.58864000010004,552.3716000003001,707.9716000003,538.0829600002,328.18894000010005,579.9722000003001,735.8722000002999,565.9835600002,411.9895400001,634.5784800003,706.4781800003,507.58924000020005,773.5898400001,1016.1790800003,753.2787800002999,550.5898400002,606.6659200002001,754.4545600002001,651.5659200002001,373.6659200002,720.1772800002001,917.0659200003003,749.9716000003002,467.1772800001999,1209.2659200003,1451.1488800004995,884.9545600005001,547.1716000002999,1382.7772800002997,1708.860240000599,1078.4602400006002,708.7829600003,316.58864000010004,552.3716000003001,707.9716000003,538.0829600002,328.18894000010005,579.9722000003001,735.8722000002999,565.9835600002,411.9895400001,634.5784800003,706.4781800003,507.58924000020005,773.5898400001,1016.1790800003,753.2787800002999,550.5898400002,606.6659200002001,754.4545600002001,651.5659200002001,373.6659200002,720.1772800002001,917.0659200003003,749.9716000003002,467.1772800001999,1209.2659200003,1451.1488800004995,884.9545600005001,547.1716000002999,1382.7772800002997,1708.860240000599,1078.4602400006002,708.7829600003,316.58864000010004,552.3716000003001,707.9716000003,538.0829600002,328.18894000010005,579.9722000003001,735.8722000002999,565.9835600002,411.9895400001,634.5784800003,706.4781800003,507.58924000020005,773.5898400001,1016.1790800003,753.2787800002999,550.5898400002,606.6659200002001,754.4545600002001,651.5659200002001,373.6659200002,720.1772800002001,917.0659200003003,749.9716000003002,467.1772800001999,1209.2659200003,1451.1488800004995,884.9545600005001,547.1716000002999,1382.7772800002997,1708.860240000599,1078.4602400006002,708.7829600003,316.58864000010004,552.3716000003001,707.9716000003,538.0829600002,328.18894000010005,579.9722000003001,735.8722000002999,565.9835600002,411.9895400001,634.5784800003,706.4781800003,507.58924000020005,773.5898400001,1016.1790800003,753.2787800002999,550.5898400002])
    (rev' @Double @4 conv2dBPadded128c t128b)

testKonst5BigCPadded128cb :: Assertion
testKonst5BigCPadded128cb =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [4,2,4,4] [720.1835600001002,1155.4608400002999,1436.2438000003995,1150.0548600004,1222.1722000002,1772.649480000399,2138.4267600005987,1463.1378200005997,1477.3665200002997,2157.0438000004983,2640.8154000007976,1836.3264600007988,1366.1659200002998,1965.6432000004988,2412.414800000798,1662.026160000799,226.3886400001001,671.8832600001001,1012.8665100002999,1078.3665100003,419.3835600001001,961.7781800001002,1359.4557500003987,1310.9554500003997,568.9778800002001,1233.4728000001987,1712.044990000598,1566.644690000599,428.6778800002001,1007.0784800001993,1445.5506800005985,1358.3503800005992,720.1835600001002,1155.4608400002999,1436.2438000003995,1150.0548600004,1222.1722000002,1772.649480000399,2138.4267600005987,1463.1378200005997,1477.3665200002997,2157.0438000004983,2640.8154000007976,1836.3264600007988,1366.1659200002998,1965.6432000004988,2412.414800000798,1662.026160000799,226.3886400001001,671.8832600001001,1012.8665100002999,1078.3665100003,419.3835600001001,961.7781800001002,1359.4557500003987,1310.9554500003997,568.9778800002001,1233.4728000001987,1712.044990000598,1566.644690000599,428.6778800002001,1007.0784800001993,1445.5506800005985,1358.3503800005992,720.1835600001002,1155.4608400002999,1436.2438000003995,1150.0548600004,1222.1722000002,1772.649480000399,2138.4267600005987,1463.1378200005997,1477.3665200002997,2157.0438000004983,2640.8154000007976,1836.3264600007988,1366.1659200002998,1965.6432000004988,2412.414800000798,1662.026160000799,226.3886400001001,671.8832600001001,1012.8665100002999,1078.3665100003,419.3835600001001,961.7781800001002,1359.4557500003987,1310.9554500003997,568.9778800002001,1233.4728000001987,1712.044990000598,1566.644690000599,428.6778800002001,1007.0784800001993,1445.5506800005985,1358.3503800005992,720.1835600001002,1155.4608400002999,1436.2438000003995,1150.0548600004,1222.1722000002,1772.649480000399,2138.4267600005987,1463.1378200005997,1477.3665200002997,2157.0438000004983,2640.8154000007976,1836.3264600007988,1366.1659200002998,1965.6432000004988,2412.414800000798,1662.026160000799,226.3886400001001,671.8832600001001,1012.8665100002999,1078.3665100003,419.3835600001001,961.7781800001002,1359.4557500003987,1310.9554500003997,568.9778800002001,1233.4728000001987,1712.044990000598,1566.644690000599,428.6778800002001,1007.0784800001993,1445.5506800005985,1358.3503800005992])
    (rev' @Double @4 conv2dCPadded128c t128b)

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigCPadded128cb :: Assertion
testKonstNotBigCPadded128cb =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [4,2,4,4] [720.1835600001002,1155.4608400002999,1436.2438000003995,1150.0548600004,1222.1722000002,1772.649480000399,2138.4267600005987,1463.1378200005997,1477.3665200002997,2157.0438000004983,2640.8154000007976,1836.3264600007988,1366.1659200002998,1965.6432000004988,2412.414800000798,1662.026160000799,226.3886400001001,671.8832600001001,1012.8665100002999,1078.3665100003,419.3835600001001,961.7781800001002,1359.4557500003987,1310.9554500003997,568.9778800002001,1233.4728000001987,1712.044990000598,1566.644690000599,428.6778800002001,1007.0784800001993,1445.5506800005985,1358.3503800005992,720.1835600001002,1155.4608400002999,1436.2438000003995,1150.0548600004,1222.1722000002,1772.649480000399,2138.4267600005987,1463.1378200005997,1477.3665200002997,2157.0438000004983,2640.8154000007976,1836.3264600007988,1366.1659200002998,1965.6432000004988,2412.414800000798,1662.026160000799,226.3886400001001,671.8832600001001,1012.8665100002999,1078.3665100003,419.3835600001001,961.7781800001002,1359.4557500003987,1310.9554500003997,568.9778800002001,1233.4728000001987,1712.044990000598,1566.644690000599,428.6778800002001,1007.0784800001993,1445.5506800005985,1358.3503800005992,720.1835600001002,1155.4608400002999,1436.2438000003995,1150.0548600004,1222.1722000002,1772.649480000399,2138.4267600005987,1463.1378200005997,1477.3665200002997,2157.0438000004983,2640.8154000007976,1836.3264600007988,1366.1659200002998,1965.6432000004988,2412.414800000798,1662.026160000799,226.3886400001001,671.8832600001001,1012.8665100002999,1078.3665100003,419.3835600001001,961.7781800001002,1359.4557500003987,1310.9554500003997,568.9778800002001,1233.4728000001987,1712.044990000598,1566.644690000599,428.6778800002001,1007.0784800001993,1445.5506800005985,1358.3503800005992,720.1835600001002,1155.4608400002999,1436.2438000003995,1150.0548600004,1222.1722000002,1772.649480000399,2138.4267600005987,1463.1378200005997,1477.3665200002997,2157.0438000004983,2640.8154000007976,1836.3264600007988,1366.1659200002998,1965.6432000004988,2412.414800000798,1662.026160000799,226.3886400001001,671.8832600001001,1012.8665100002999,1078.3665100003,419.3835600001001,961.7781800001002,1359.4557500003987,1310.9554500003997,568.9778800002001,1233.4728000001987,1712.044990000598,1566.644690000599,428.6778800002001,1007.0784800001993,1445.5506800005985,1358.3503800005992])
    (rev' @Double @4 conv2dCPadded128c t128b)
-}


-- * Disparity and misc

-- | Disparity cost volume.
--
--   Take two arrays of multi channel 2d images, where the first contains
--   left views of the scene and the second contains right views.
--
--   For each pair of images, slice the right image over the left image,
--   and for each offset produce the L1 distance indicating how well
--   correponding
--   multi-channel image elements in the right image match those in the left.
--
--   Described in:
--    Anytime Stereo Image Depth Estimation on Mobile Devices
--    Wang, Lai et al, ICRA 2019
--    https://arxiv.org/abs/1810.11408
--    Section III b).
--
costVolume
  :: forall r target. (ADReady target, GoodScalar r)
  => Int -> Int -> target (TKR 4 r) -> target (TKR 4 r) -> target (TKR 4 r)
costVolume iStart nCount arrL arrR =
  let [nImgs, nChas, nRows, nCols] = rshape arrL
      shO = [nImgs, nCount, nRows, nCols]
  in rbuild shO $ \[iImg, iDisp, iRow, iCol] ->
       let arrVecL = rbuild (nChas :$: ZSR) $ \[iCha] ->
                       rindex0 arrL [iImg, iCha, iRow, iCol]
           iSrc = iCol - fromIntegral iStart - iDisp
           arrVecR = rbuild [nChas] $ \[iCha] ->
                       rindex0 arrR [iImg, iCha, iRow, iSrc]
       in rsum0 $ rzipWith1 (\xL xR -> abs (xL - xR)) arrVecL arrVecR

test_disparityKonst :: Assertion
test_disparityKonst = do
  let arrL :: ADReady target => target (TKR 4 Double)
      arrL = rreplicate0N [1, 2, 4, 6] (rscalar (-0.2))
      arrR :: ADReady target => target (TKR 4 Double)
      arrR = rreplicate0N [1, 2, 4, 6] (rscalar 0.3)
      arrO = costVolume @Double 0 4 arrL arrR
      arrDL = vjp (\aL -> costVolume 0 4 aL (rfromPrimal arrR)) arrL arrO
      arrDR = vjp (\aR -> costVolume 0 4 (rfromPrimal arrL) aR) arrR arrO
  assertEqualUpToEpsilon 1e-7
    (rconcrete $ Nested.rfromListPrimLinear [1,4,4,6] [1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.4,1.0,1.0,1.0,1.0,1.0,0.4,1.0,1.0,1.0,1.0,1.0,0.4,1.0,1.0,1.0,1.0,1.0,0.4,1.0,1.0,1.0,1.0,1.0,0.4,0.4,1.0,1.0,1.0,1.0,0.4,0.4,1.0,1.0,1.0,1.0,0.4,0.4,1.0,1.0,1.0,1.0,0.4,0.4,1.0,1.0,1.0,1.0,0.4,0.4,0.4,1.0,1.0,1.0,0.4,0.4,0.4,1.0,1.0,1.0,0.4,0.4,0.4,1.0,1.0,1.0,0.4,0.4,0.4,1.0,1.0,1.0])
    arrO
  assertEqualUpToEpsilon 1e-7
    (rconcrete $ Nested.rfromListPrimLinear [1,2,4,6] [-2.2,-2.8,-3.4,-4.0,-4.0,-4.0,-2.2,-2.8,-3.4,-4.0,-4.0,-4.0,-2.2,-2.8,-3.4,-4.0,-4.0,-4.0,-2.2,-2.8,-3.4,-4.0,-4.0,-4.0,-2.2,-2.8,-3.4,-4.0,-4.0,-4.0,-2.2,-2.8,-3.4,-4.0,-4.0,-4.0,-2.2,-2.8,-3.4,-4.0,-4.0,-4.0,-2.2,-2.8,-3.4,-4.0,-4.0,-4.0])
    arrDL
  assertEqualUpToEpsilon 1e-7
    (rconcrete $ Nested.rfromListPrimLinear [1,2,4,6] [4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0])
   arrDR
  assertEqualUpToEpsilon' 1e-7
    (ringestData [1,2,4,6] [4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0])
    (rev' @Double @4 (costVolume 0 4 arrL) arrR)
  assertEqualUpToEpsilon' 1e-7
    (ringestData [1,2,4,6] [-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0])
    (rev' @Double @4 (\aL -> costVolume 0 2 aL arrR) arrL)
  assertEqualUpToEpsilon' 1e-7
    (ringestData [1,2,4,6] [2.0,2.0,2.0,2.0,2.0,1.0,2.0,2.0,2.0,2.0,2.0,1.0,2.0,2.0,2.0,2.0,2.0,1.0,2.0,2.0,2.0,2.0,2.0,1.0,2.0,2.0,2.0,2.0,2.0,1.0,2.0,2.0,2.0,2.0,2.0,1.0,2.0,2.0,2.0,2.0,2.0,1.0,2.0,2.0,2.0,2.0,2.0,1.0])
    (rev' @Double @4 (costVolume 0 2 arrL) arrR)

test_disparityKonst2 :: Assertion
test_disparityKonst2 = do
  let arrL :: (BaseTensor target, GoodScalar r, Differentiable r) => target (TKR 4 r)
      arrL = ringestData [1, 2, 4, 6] [0.4,0.4,0.4,1.0,1.0,1.0,0.4,0.4,0.4,1.0,1.0,1.0,0.4,0.4,0.4,1.0,1.0,1.0, 1.7041241452319316,1.21999,0.21355339059327375,0.7867666666666666,0.7331698975466578,0.6964466094067263,1.1,1.1041141452319316,0.42000000000000004,0.3536533905932737,0.78,1.253169897546658,1.1,0.50001,0.42000000000000004,0.2801,0.78,1.3,1.1,0.50001,0.42000000000000004,0.2801,0.78,1.3,2.808238290463863,1.21999,-0.5672067811865474,0.7867666666666666,1.986339795093316,0.6964466094067263]
      arrR :: (BaseTensor target, GoodScalar r, Differentiable r) => target (TKR 4 r)
      arrR = ringestData [1, 2, 4, 6] [0.2, 0.5, -0.2, 0.0001, 0.44, 0.9, -0.9, 0.00001, -0.22, -0.28, -0.34, -0.40, -0.40,-0.22,-0.28,-0.34, 0.22360679774997896,0.35355339059327373,0.20412414523193154,0.5, -0.35355339059327373,0.16666666666666666,0.17677669529663687,-0.25, -2.808238290463863,-1.21999,-0.5672067811865474,-0.7867666666666666,-1.986339795093316,-0.6964466094067263,2.808238290463863,1.21999,-0.5672067811865474,0.7867666666666666,0.6964466094067263,0.42000000000000004,0.3536533905932737,0.78,1.253169897546658,0.50001,0.42000000000000004,0.2801,0.78,1.1,0.50001,0.42000000000000004,0.2801,0.78]
      arrO = rreplicate0N [1, 4, 4, 6] (rscalar (1 :: Double))
      res1 = rconcrete $ Nested.rfromListPrimLinear [1,2,4,6] [4.0,2.0,2.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,2.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,2.0,0.0,0.0,-2.0,0.0,4.0,4.0,2.0,0.0,-4.0,1.0,4.0,4.0,4.0,-4.0,2.0,4.0,2.0]
      res2 = rconcrete $ Nested.rfromListPrimLinear [1,2,4,6] [-4.0,0.0,-4.0,-3.0,-2.0,-1.0,-4.0,-4.0,-4.0,-3.0,-2.0,-1.0,-4.0,-4.0,-4.0,-3.0,-2.0,-1.0,-4.0,-2.0,-4.0,-3.0,-2.0,-1.0,-4.0,-4.0,-4.0,-3.0,-2.0,-1.0,4.0,4.0,-4.0,1.0,-2.0,-1.0,-2.0,3.0,2.0,-1.0,-2.0,-1.0,-2.0,0.0,-2.0,-3.0,-2.0,1.0]
      arrDL :: Concrete (TKR 4 Double)
      arrDL = vjp (\aL -> costVolume 0 4 aL (rfromPrimal arrR)) arrL arrO
      arrDR :: Concrete (TKR 4 Double)
      arrDR = vjp (costVolume 0 4 (rfromPrimal arrL)) arrR arrO
  assertEqualUpToEpsilon 1e-7
    res1
    arrDL
  assertEqualUpToEpsilon 1e-7
    res2
    arrDR
  assertEqualUpToEpsilon' 1e-7
    res1
    (rev' @Double @4 (\aL -> costVolume 0 4 aL (rfromPrimal arrR)) arrL)
  assertEqualUpToEpsilon' 1e-7
    res2
    (rev' @Double @4 (costVolume 0 4 (rfromPrimal arrL)) arrR)

test_disparitySmall :: Assertion
test_disparitySmall = do
  let arrL :: ADReady target => target (TKR 4 Double)
      arrL = ringestData [1, 2, 3, 2] [0.2, 0.5, -0.2, 0.0001, 0.44, 0.9, -0.9, 0.00001, -0.22, -0.28, -0.34, -0.40]
      arrR :: ADReady target => target (TKR 4 Double)
      arrR = ringestData [1, 2, 3, 2] [-0.40,-0.22,-0.28,-0.34, 0.22360679774997896,0.35355339059327373,0.20412414523193154,0.5, -0.35355339059327373,0.16666666666666666,0.17677669529663687,-0.25]
      arrO = costVolume @Double 0 4 arrL arrR
      arrDL = vjp (\aL -> costVolume 0 4 aL (rfromPrimal arrR)) arrL arrO
      arrDR = vjp (\aR -> costVolume 0 4 (rfromPrimal arrL) aR) arrR arrO
  assertEqualUpToEpsilon 1e-7
    (rconcrete $ Nested.rfromListPrimLinear [1,4,3,2] [1.7041241452319316,1.21999,0.21355339059327375,0.7867666666666666,0.7331698975466578,0.6964466094067263,1.1,1.1041141452319316,0.42000000000000004,0.3536533905932737,0.78,1.253169897546658,1.1,0.50001,0.42000000000000004,0.2801,0.78,1.3,1.1,0.50001,0.42000000000000004,0.2801,0.78,1.3])
    arrO
  assertEqualUpToEpsilon' 1e-7
    (ringestData [1,2,3,2] [-2.0,-1.0,-2.0,-1.0,-2.0,-1.0,2.0,1.0,-2.0,1.0,2.0,1.0])
    (rev' @Double @4 (costVolume 0 4 arrL) arrR)
  assertEqualUpToEpsilon 1e-7
    (rconcrete $ Nested.rfromListPrimLinear [1,2,3,2] [5.004124145231932,3.3241241452319317,-1.0464466094067264,1.7006200572599404,3.0731698975466575,4.5496165069533845,-5.004124145231932,-1.3240841452319316,-1.0464466094067264,-0.9933132760733929,-3.0731698975466575,-4.5496165069533845])
    arrDL
  assertEqualUpToEpsilon 1e-7
    (rconcrete $ Nested.rfromListPrimLinear [1,2,3,2] [-2.808238290463863,-1.21999,-0.5672067811865474,-0.7867666666666666,-1.986339795093316,-0.6964466094067263,2.808238290463863,1.21999,-0.5672067811865474,0.7867666666666666,1.986339795093316,0.6964466094067263])
   arrDR
  assertEqualUpToEpsilon' 1e-7
    (ringestData [1,2,3,2] [-1.0,0.0,-1.0,0.0,-1.0,0.0,1.0,0.0,-1.0,0.0,1.0,0.0])
    (rev' @Double @4 (costVolume 1 4 arrL) arrR)
  assertEqualUpToEpsilon' 1e-7
    (ringestData [1,2,3,2] [2.0,2.0,-2.0,2.0,2.0,2.0,-2.0,2.0,-2.0,-2.0,-2.0,-2.0])
    (rev' @Double @4 (\aL -> costVolume 2 2 aL arrR) arrL)
  assertEqualUpToEpsilon' 1e-7
    (ringestData [1,2,3,2] [-1.0,0.0,-1.0,0.0,-1.0,0.0,1.0,0.0,-1.0,0.0,1.0,0.0])
    (rev' @Double @4 (costVolume 1 2 arrL) arrR)

codeTomsSlice :: ADReady target
              => target (TKR 2 Double) -> target (TKR 0 Double)
codeTomsSlice a =
  let (n, m) = case rshape a of
        [n', m'] -> (n', m')
        _ -> error "codeTomsSlice"
      a1 = rbuild @2 @0 [n,m-1] (\[i',j'] -> rindex0 a [i',j'])
      a2 = rbuild [n,m-1] (\[i',j'] -> rindex0 a [i',j' + 1])
  in rsum0 @2 $ rbuild [n,m] $ \[i, _j] ->
       rfromIndex0 i * rsum0 (a1 * a2)

testTomsSliceRev :: Assertion
testTomsSliceRev = do
  assertEqualUpToEpsilon 1e-5
    (ringestData [32,4] [63686.39999999999,137292.80000000002,121222.4,79558.40000000002,192646.40000000005,223971.0617601984,228556.80000000005,116846.33088019838,63686.39999999999,137292.80000000002,127174.4,79558.40000000002,192646.40000000005,158499.06176019844,202566.40000000005,51374.330880198424,11904.0,5952.0,7936.0,1984.0,116846.33088019838,385292.8000000001,227740.66176039676,192646.40000000005,116846.33088019838,228556.80000000005,174580.73088019836,35910.399999999994,79558.40000000002,127372.79999999997,143244.80000000002,63686.39999999999,105152.0,186683.13088000007,105151.98016,107124.73088000003,-396.79999999999995,26188.8,17459.2,25990.399999999998,-7936.0,73408.0,-1995.2691200000017,57536.0,51584.0,-660672.0,55552.0,3968.0,3968.0,3571.2,3571.2,-396.79999999999995,-396.79999999999995,49203.79519999998,49203.79519999998,49600.59519999998,49600.59519999998,49203.79519999998,49203.79519999998,-396.79999999999995,-396.79999999999995,49203.79519999998,49203.79519999998,49600.59519999998,49600.59519999998,129158.9952,65472.59519999998,79558.40000000002,-5952.0,73198.33087999995,51175.930880000036,51374.33087999995,51187.20000000001,1984.0000000000146,67059.20000000001,79558.40000000002,-5952.0,73198.33087999995,51175.930880000036,51374.33087999995,51187.20000000001,-21823.99999999993,108921.6,16070.400000000005,79558.40000000002,127372.79999999997,159116.80000000005,63686.39999999999,107124.73088000003,771974.4,218019.0617601984,192646.40000000005,170414.3308801984,385292.8000000001,340828.6617603968,192646.40000000005,57734.399999999994,99596.79999999999,137292.80000000002,63686.39999999999,79558.40000000002,127372.79999999997,159116.80000000005,63686.39999999999,107124.73088000003,236294.40000000005,271587.0617601984,192646.40000000005,45422.33088019842,385292.8000000001,162268.6617603968,192646.40000000005,57734.399999999994,99596.79999999999,137292.80000000002,63686.39999999999,79558.40000000002,127372.79999999997,159116.80000000005,63686.39999999999,107124.73088000003,369222.4,220003.0617601984,192646.40000000005,104942.33088019838,385292.8000000001,215836.66176039676,192646.40000000005])
    (grad (kfromR . codeTomsSlice) (rreshape [32, 4] t128))

testTomsSlice :: Assertion
testTomsSlice = do
  assertEqualUpToEpsilon' 1e-5
    (ringestData [32,4] [63686.39999999999,137292.80000000002,121222.4,79558.40000000002,192646.40000000005,223971.0617601984,228556.80000000005,116846.33088019838,63686.39999999999,137292.80000000002,127174.4,79558.40000000002,192646.40000000005,158499.06176019844,202566.40000000005,51374.330880198424,11904.0,5952.0,7936.0,1984.0,116846.33088019838,385292.8000000001,227740.66176039676,192646.40000000005,116846.33088019838,228556.80000000005,174580.73088019836,35910.399999999994,79558.40000000002,127372.79999999997,143244.80000000002,63686.39999999999,105152.0,186683.13088000007,105151.98016,107124.73088000003,-396.79999999999995,26188.8,17459.2,25990.399999999998,-7936.0,73408.0,-1995.2691200000017,57536.0,51584.0,-660672.0,55552.0,3968.0,3968.0,3571.2,3571.2,-396.79999999999995,-396.79999999999995,49203.79519999998,49203.79519999998,49600.59519999998,49600.59519999998,49203.79519999998,49203.79519999998,-396.79999999999995,-396.79999999999995,49203.79519999998,49203.79519999998,49600.59519999998,49600.59519999998,129158.9952,65472.59519999998,79558.40000000002,-5952.0,73198.33087999995,51175.930880000036,51374.33087999995,51187.20000000001,1984.0000000000146,67059.20000000001,79558.40000000002,-5952.0,73198.33087999995,51175.930880000036,51374.33087999995,51187.20000000001,-21823.99999999993,108921.6,16070.400000000005,79558.40000000002,127372.79999999997,159116.80000000005,63686.39999999999,107124.73088000003,771974.4,218019.0617601984,192646.40000000005,170414.3308801984,385292.8000000001,340828.6617603968,192646.40000000005,57734.399999999994,99596.79999999999,137292.80000000002,63686.39999999999,79558.40000000002,127372.79999999997,159116.80000000005,63686.39999999999,107124.73088000003,236294.40000000005,271587.0617601984,192646.40000000005,45422.33088019842,385292.8000000001,162268.6617603968,192646.40000000005,57734.399999999994,99596.79999999999,137292.80000000002,63686.39999999999,79558.40000000002,127372.79999999997,159116.80000000005,63686.39999999999,107124.73088000003,369222.4,220003.0617601984,192646.40000000005,104942.33088019838,385292.8000000001,215836.66176039676,192646.40000000005])
    (rev' codeTomsSlice (rreshape [32, 4] t128))


-- * PP Tests

testTomsSlicePP :: Assertion
testTomsSlicePP = do
  resetVarCounter
  let artifactRev = revArtifactAdapt UseIncomingCotangent codeTomsSlice (FTKR [32, 4] FTKScalar)
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\m1 -> rfromS (sscalar 4.0 * sdot0 (sconcrete (sfromListLinear [32] [0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0,16.0,17.0,18.0,19.0,20.0,21.0,22.0,23.0,24.0,25.0,26.0,27.0,28.0,29.0,30.0,31.0])) (sreplicate @32 (sdot0 (sslice (SNat @0) (SNat @3) (str (sfromR m1))) (sslice (SNat @1) (SNat @3) (str (sfromR m1))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\m1 -> let v8 = sreplicate @32 (ssum @96 (sreshape @[96] (str (sslice (SNat @0) (SNat @3) (str (sfromR m1))) * str (sslice (SNat @1) (SNat @3) (str (sfromR m1)))))) in rfromS (ssum @128 (sreshape @[128] (str (sreplicate @4 (siota (SNat @32) * v8)))))"
  printArtifactPretty artifactRev
    @?= "\\dret m1 -> let m10 = sreshape @[32,3] (sreplicate @96 (ssum @32 (siota (SNat @32) * ssum @4 (str (sreshape @[32,4] (sreplicate @128 (sfromR dret))))))) in rfromS (str (sappend (sconcrete (sfromListLinear [0,32] [])) (sappend (str (str (sslice (SNat @1) (SNat @3) (str (sfromR m1))) * m10)) (sconcrete (sreplicate [1,32] 0.0)))) + str (sappend (sconcrete (sreplicate [1,32] 0.0)) (sappend (str (str (sslice (SNat @0) (SNat @3) (str (sfromR m1))) * m10)) (sconcrete (sfromListLinear [0,32] [])))))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret m1 -> rfromS (let x10 = sdot0 (sconcrete (sfromListLinear [32] [0.0,4.0,8.0,12.0,16.0,20.0,24.0,28.0,32.0,36.0,40.0,44.0,48.0,52.0,56.0,60.0,64.0,68.0,72.0,76.0,80.0,84.0,88.0,92.0,96.0,100.0,104.0,108.0,112.0,116.0,120.0,124.0])) (sreplicate @32 (sfromR dret)) in str (sappend (sslice (SNat @1) (SNat @3) (str (sfromR m1)) * sreplicate @3 (sreplicate @32 x10)) (sconcrete (sreplicate [1,32] 0.0))) + str (sappend (sconcrete (sreplicate [1,32] 0.0)) (sslice (SNat @0) (SNat @3) (str (sfromR m1)) * sreplicate @3 (sreplicate @32 x10))))"

testCNNOPP0c :: Assertion
testCNNOPP0c = do
  resetVarCounter
  let artifactRev = revArtifactAdapt UseIncomingCotangent conv2dCLaborious (FTKR [2, 2, 2, 2] (FTKScalar @Double))
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[4,0,1,5,2,3] (sgather (sfromVector (fromList [stranspose @[3,0,5,1,2,4] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i86, i88] -> [i86 + i88]))) (\\[i40, i41] -> [i40 + i41])), sconcrete (sreplicate [2,2,2,2,2,2] 0.0)])) (\\[i42, i43, i44, i45] -> [ifH (notB (2 <=. i42 + i44) &&* notB (2 <=. i43 + i45)) 0 1, i42, i43, i44, i45])))))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let w46 = str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[4,0,1,5,2,3] (sgather (sfromVector (fromList [stranspose @[3,0,5,1,2,4] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i38, i39] -> [i38 + i39]))) (\\[i40, i41] -> [i40 + i41])), sconcrete (sreplicate [2,2,2,2,2,2] 0.0)])) (\\[i42, i43, i44, i45] -> [ifH (notB (2 <=. i42 + i44) &&* notB (2 <=. i43 + i45)) 0 1, i42, i43, i44, i45])))))) in rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (w46 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> let w46 = str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[4,0,1,5,2,3] (sgather (sfromVector (fromList [stranspose @[3,0,5,1,2,4] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i38, i39] -> [i38 + i39]))) (\\[i40, i41] -> [i40 + i41])), sconcrete (sreplicate [2,2,2,2,2,2] 0.0)])) (\\[i42, i43, i44, i45] -> [ifH (notB (2 <=. i42 + i44) &&* notB (2 <=. i43 + i45)) 0 1, i42, i43, i44, i45])))))) in rfromS (ssum @1 (str (ssum @2 (str (ssum @2 (str (ssum @2 (w46 * sreshape @[2,2,2,2,1,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret)))))))))))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> rfromS (ssum @2 (ssum @2 (sdot1In (stranspose @[0,1,2,5,3,4] (sgather (sfromVector (fromList [stranspose @[3,0,5,1,4,2] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i112, i114] -> [i112 + i114]))) (\\[i40, i41] -> [i40 + i41])), sconcrete (sreplicate [2,2,2,2,2,2] 0.0)])) (\\[i96, i97, i98, i103, i104] -> [ifH (notB (2 <=. i96 + i103) &&* notB (2 <=. i97 + i104)) 0 1, i96, i97, i103, i104]))) (stranspose @[4,2,3,1,5,6,7,0] (sreshape @[2,2,2,2,1,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret)))) !$ [0]))))"

testCNNOPP0b :: Assertion
testCNNOPP0b = do
  resetVarCounter
  let artifactRev = revArtifactAdapt UseIncomingCotangent conv2dBLaborious (FTKR [2, 2, 2, 2] (FTKScalar @Double))
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (sconcrete (sfromListLinear [2,2,2,2,1,2,2,2] [5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0]) * str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[4,0,1,5,2,3] (sgather (sfromVector (fromList [stranspose @[3,0,5,1,2,4] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i108, i110] -> [i108 + i110]))) (\\[i47, i48] -> [i47 + i48])), sconcrete (sreplicate [2,2,2,2,2,2] 0.0)])) (\\[i49, i50, i51, i52] -> [ifH (notB (2 <=. i49 + i51) &&* notB (2 <=. i50 + i52)) 0 1, i49, i50, i51, i52]))))))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let w53 = str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[4,0,1,5,2,3] (sgather (sfromVector (fromList [stranspose @[3,0,5,1,2,4] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i45, i46] -> [i45 + i46]))) (\\[i47, i48] -> [i47 + i48])), sconcrete (sreplicate [2,2,2,2,2,2] 0.0)])) (\\[i49, i50, i51, i52] -> [ifH (notB (2 <=. i49 + i51) &&* notB (2 <=. i50 + i52)) 0 1, i49, i50, i51, i52])))))) in rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (sconcrete (sfromListLinear [2,2,2,2,1,2,2,2] [5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0]) * w53))))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> let w59 = sscatter (stranspose @[1,2,4,5,0,3] (ssum @1 (stranspose @[3,0,1,2] (ssum @2 (str (sconcrete (sfromListLinear [2,2,2,2,1,2,2,2] [5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0]) * sreshape @[2,2,2,2,1,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret))))))))) (\\[i55, i56, i57, i58] -> [ifH (notB (2 <=. i55 + i57) &&* notB (2 <=. i56 + i58)) 0 1, i55, i56, i57, i58]) in rfromS (stranspose @[1,2,0] (sscatter (stranspose @[2,4,1,3,0] (sscatter (stranspose @[1,3,4,0,5,2] (w59 !$ [0])) (\\[i60, i61] -> [i60 + i61]))) (\\[i62, i63] -> [i62 + i63])))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> rfromS (stranspose @[1,2,0] (sscatter (stranspose @[2,4,1,3,0] (sscatter (stranspose @[0,2,4,5,1,6,3] (sscatter (sdot1In (sconcrete (sfromListLinear [2,2,2,2,2,2,2] [5.0,13.1,-2.0,582934.0,5.0,13.1,-2.0,582934.0,2.0,9.0,0.0,2.99432,2.0,9.0,0.0,2.99432,6.0,8.0,0.1,-335.0,6.0,8.0,0.1,-335.0,1.0,-4.0,-0.2,26.0,1.0,-4.0,-0.2,26.0,5.0,13.1,-2.0,582934.0,5.0,13.1,-2.0,582934.0,2.0,9.0,0.0,2.99432,2.0,9.0,0.0,2.99432,6.0,8.0,0.1,-335.0,6.0,8.0,0.1,-335.0,1.0,-4.0,-0.2,26.0,1.0,-4.0,-0.2,26.0,5.0,13.1,-2.0,582934.0,5.0,13.1,-2.0,582934.0,2.0,9.0,0.0,2.99432,2.0,9.0,0.0,2.99432,6.0,8.0,0.1,-335.0,6.0,8.0,0.1,-335.0,1.0,-4.0,-0.2,26.0,1.0,-4.0,-0.2,26.0,5.0,13.1,-2.0,582934.0,5.0,13.1,-2.0,582934.0,2.0,9.0,0.0,2.99432,2.0,9.0,0.0,2.99432,6.0,8.0,0.1,-335.0,6.0,8.0,0.1,-335.0,1.0,-4.0,-0.2,26.0,1.0,-4.0,-0.2,26.0])) (stranspose @[4,2,3,6,7,0,5,1] (sreshape @[2,2,2,2,1,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret)))) !$ [0])) (\\[i55, i56, i57, i58] -> [ifH (notB (2 <=. i55 + i57) &&* notB (2 <=. i56 + i58)) 0 1, i55, i56, i57, i58])) !$ [0]) (\\[i60, i61] -> [i60 + i61]))) (\\[i62, i63] -> [i62 + i63])))"

testCNNOPP1e :: Assertion
testCNNOPP1e = do
  resetVarCounter
  let f :: AstTensor AstMethodLet FullSpan
                     (TKProduct (TKR 4 Double) (TKR 4 Double))
        -> AstTensor AstMethodLet FullSpan
                     (TKR 4 Double)
      f v = conv2dUnpaddedL (tproject1 v) (tproject2 v)
      ftk = FTKProduct (FTKR (2 :$: 2 :$: 2 :$: 2 :$: ZSR) FTKScalar)
                       (FTKR (2 :$: 2 :$: 2 :$: 2 :$: ZSR) FTKScalar)
      (artifactRev, _) =
        revArtifactFromForwardPass
          UseIncomingCotangent (forwardPassByInterpretation f emptyEnv) ftk
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (str (sreplicate @2 (stranspose @[4,0,1,5,2,3] (sgather (sfromVector (fromList [stranspose @[3,0,5,1,2,4] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR (tproject2 u1))) (\\[i93, i95] -> [i93 + i95]))) (\\[i30, i31] -> [i30 + i31])), sconcrete (sreplicate [2,2,2,2,2,2] 0.0)])) (\\[i32, i33, i34, i35] -> [ifH (notB (2 <=. i32 + i34) &&* notB (2 <=. i33 + i35)) 0 1, i32, i33, i34, i35])))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (sfromR (tproject1 u1))))))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let w36 = str (sreplicate @2 (stranspose @[4,0,1,5,2,3] (sgather (sfromVector (fromList [stranspose @[3,0,5,1,2,4] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR (tproject2 u1))) (\\[i28, i29] -> [i28 + i29]))) (\\[i30, i31] -> [i30 + i31])), sconcrete (sreplicate [2,2,2,2,2,2] 0.0)])) (\\[i32, i33, i34, i35] -> [ifH (notB (2 <=. i32 + i34) &&* notB (2 <=. i33 + i35)) 0 1, i32, i33, i34, i35])))) in rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (w36 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (sfromR (tproject1 u1))))))))))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> tconvert (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2,2,2] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2,2,2] FTKScalar)) ConvSX))) (STKProduct (STKS [2,2,2,2] STKScalar) (STKS [2,2,2,2] STKScalar)) (let w38 = sreshape @[2,2,2,2,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret))) in tpair (ssum @2 (ssum @2 (sdot1In (stranspose @[2,3,0,4,5,6,1] (sreplicate @2 (stranspose @[4,0,1,5,2,3] (sgather (sfromVector (fromList [stranspose @[3,0,5,1,2,4] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR (tproject2 u1))) (\\[i124, i126] -> [i124 + i126]))) (\\[i30, i31] -> [i30 + i31])), sconcrete (sreplicate [2,2,2,2,2,2] 0.0)])) (\\[i32, i33, i34, i35] -> [ifH (notB (2 <=. i32 + i34) &&* notB (2 <=. i33 + i35)) 0 1, i32, i33, i34, i35]))))) (stranspose @[2,3,1,4,5,6,0] w38)))) (stranspose @[1,2,0] (sscatter (stranspose @[2,4,1,3,0] (sscatter (stranspose @[0,2,4,5,1,6,3] (sscatter (sdot1In (stranspose @[2,3,5,6,0,4,1] (sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (sfromR (tproject1 u1)))))))) (stranspose @[2,3,5,6,0,4,1] w38)) (\\[i39, i40, i41, i42] -> [ifH (notB (2 <=. i39 + i41) &&* notB (2 <=. i40 + i42)) 0 1, i39, i40, i41, i42])) !$ [0]) (\\[i44, i45] -> [i44 + i45]))) (\\[i46, i47] -> [i46 + i47]))))"

testCNNOPP2 :: Assertion
testCNNOPP2 = do
  resetVarCounter
  let t = maxPool2dUnpadded2
            (rconcrete $ Nested.rreplicateScal (1 :$: 1 :$: 2 :$: 2 :$: ZSR) 1)
  printAstPretty (simplifyInlineContract t)
    @?= "rfromS (sreplicate @2 (sreplicate @2 (stranspose @[2,3,1,0] (sappend (sreplicate @1 (sgather (sreplicate @1 (stranspose @[2,0,1] (sgather (sconcrete (sfromListLinear [2,2] [1.0,1.0,1.0,1.0])) (\\[i68, i69] -> [i69 + i68])))) (\\[i44, i35, i8] -> [i8, i8, i8, 2 * i44 + i35]))) (sconcrete (sreplicate [1,2,2,2] 0.0))) !$ [0, 0])))"
  printAstPretty t
    @?= "rfromS (sreplicate @2 (sreplicate @2 (let u36 = let u41 = sgather (sgather (sreplicate @1 (let w32 = sgather (stranspose @[3,2,0,1] (sgather (sconcrete (sfromListLinear [2,3,2] [1.0,1.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0])) (\\[i26, i15] -> [i26 + i15]))) (\\[i22, i16] -> [i22 + i16]) in stranspose @[1,2,3,0] (sappend (sreplicate @1 (stranspose @[2,0,4,1,3] w32 !$ [0])) (sconcrete (sreplicate [2,2,2,2,2] 0.0))))) (\\[i20] -> [i20, i20, i20, 0])) (\\[i44, i39, i35, i8] -> [2 * i39 + i8, i39, 2 * i44 + i35]) in str (sappend (sreplicate @1 (str u41 !$ [0])) (sconcrete (sreplicate [1,2,2,2] 0.0))) in stranspose @[2,3,0,1] u36 !$ [0, 0])))"

testCNNOPP2b :: Assertion
testCNNOPP2b = do
  resetVarCounter
  let artifactRev = revArtifactAdapt UseIncomingCotangent maxPool2dUnpadded2 (FTKR [1, 1, 2, 2] (FTKScalar @Double))
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (sreplicate @2 (sreplicate @2 (stranspose @[2,3,1,0] (sappend (sreplicate @1 (sgather (sreplicate @1 (stranspose @[2,0,1] (sgather (sfromR u1 !$ [0, 0]) (\\[i92, i93] -> [i93 + i92])))) (\\[i94, i95, i96] -> [i96, i96, i96, 2 * i94 + i95]))) (sconcrete (sreplicate [1,2,2,2] 0.0))) !$ [0, 0])))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> rfromS (sreplicate @2 (sreplicate @2 (stranspose @[2,3,1,0] (sappend (sreplicate @1 (sgather (sreplicate @1 (stranspose @[2,0,1] (sgather (sfromR u1 !$ [0, 0]) (\\[i92, i93] -> [i93 + i92])))) (\\[i94, i95, i96] -> [i96, i96, i96, 2 * i94 + i95]))) (sconcrete (sreplicate [1,2,2,2] 0.0))) !$ [0, 0])))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> let u98 = stranspose @[3,2,0,1] (soneHot (ssum @2 (ssum @2 (sfromR dret))) [0, 0]) in rfromS (soneHot (sscatter (stranspose @[1,2,0] (ssum @1 (sscatter (ssum @1 (sslice (SNat @0) (SNat @1) u98)) (\\[i99, i100, i101] -> [i101, i101, i101, 2 * i99 + i100])))) (\\[i102, i103] -> [i103 + i102])) [0, 0])"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> rfromS (sreplicate @1 (sreplicate @1 (sscatter (sscatter (stranspose @[3,2,0,1] (soneHot (ssum @2 (ssum @2 (sfromR dret))) [0, 0]) !$ [0]) (\\[i99, i100, i101] -> [i101, i101, 2 * i99 + i100, i101]) !$ [0]) (\\[i102, i103] -> [i103 + i102]))))"

maxPool2dUnpadded2
  :: (target ~ AstTensor AstMethodLet FullSpan, r ~ Double)
  => target (TKR 4 r) -> target (TKR 4 r)
maxPool2dUnpadded2 a =
  rbuild [2, 2, 2, 2] $ \case
    [_, _, iBh, iBw] ->
      let arrt = slicez2 (conv2dUnpadded2 a) [iBw, 1, 2 * iBh, 2 * iBw]
      in rmaximum2 arrt
    _ -> error "maxPool2dUnpadded2: impossible pattern needlessly required"

conv2dUnpadded2
  :: (target ~ AstTensor AstMethodLet FullSpan, r ~ Double)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dUnpadded2 a =
  rbuild [3, 3, 2, 2] $ \case
    [iImg, _, iBh, iBw] ->
      let arrAt = slicez2 a [iImg, 0, iBh, iBw]
      in rindex0 arrAt [0, iBw, iBw, 0]
    _ -> error "conv2dUnpadded2: impossible pattern needlessly required"

slicez2
  :: (target ~ AstTensor AstMethodLet FullSpan, r ~ Double, n ~ 4)
  => target (TKR n r) -> IxROf target n -> target (TKR n r)
slicez2 d ixBase =
  rbuild [1, 1, 2, 2] $ \ixResult -> indexz02 d (ixrZipWith (+) ixBase ixResult)

indexz02
  :: forall target r n.
     (target ~ AstTensor AstMethodLet FullSpan, r ~ Double, n ~ 4)
  => target (TKR n r) -> IxROf target n -> target (TKR 0 r)
indexz02 d ix = ifH (1 >. (toList ix !! 0)) (d ! ix) (rscalar 0)

rmaximum2 :: (target ~ AstTensor AstMethodLet FullSpan, r ~ Double)
         => target (TKR 4 r) -> target (TKR 0 r)
rmaximum2 t0 = tlet t0 $ \t -> rindex0 t [0, 0, 0, 0]

{- TODO: divergent result; bring back when GHC 9.10 dropped:
testCNNOPP3 :: Assertion
testCNNOPP3 = do
  resetVarCounter
  let blackGlyph :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
      blackGlyph = AstFromPrimal $ AstReplicate (SNat @2) knownSTK
                   $ AstReplicate (SNat @2) knownSTK
                   $ AstReplicate (SNat @2) knownSTK
                   $ AstReplicate (SNat @2) knownSTK
                       (rconcrete $ Nested.rscalar 7
                        :: AstTensor AstMethodLet PrimalSpan (TKR 0 Double))
      afcnn2T :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
      afcnn2T = maxPool2dUnpadded33 $ conv2dUnpadded3 blackGlyph
  printAstPretty (simplifyInlineContract afcnn2T)
    @?= "rfromS (sreplicate @2 (sgather (stranspose @[2,1,0,4,3] (sappend (sreplicate @1 (sgather (sconcrete (sfromListLinear [2] [7.0,0.0])) (\\[i18, i22, i17, i15] -> [ifH (notB (2 <=. remH i22 4 + i18) &&* (notB (2 <=. i22 + i17) &&* notB (2 <=. i22 + i15))) 0 1]))) (sconcrete (sfromListLinear [1,2,2,2,2] [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0])))) (\\[i52, i51] -> [remH i51 4, i52, i52, remH i51 4])))"
  printAstPretty afcnn2T
    @?= "rfromS (let w30 = sgather (sfromVector (fromList [stranspose @[4,0,1,2,5,3] (sgather (stranspose @[1,2,4,5,0,3] (sgather (sappend (sreplicate @1 (sgather (sconcrete (sfromListLinear [2] [7.0,0.0])) (\\[i18, i22, i17, i15] -> [ifH (notB (2 <=. remH i22 4 + i18) &&* (notB (2 <=. i22 + i17) &&* notB (2 <=. i22 + i15))) 0 1]))) (sreplicate @1 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sscalar 0.0))))))) (\\[i43, i38, i29, i7] -> [i43 + i7, i43 + i7, remH i38 4 + i29]))) (\\[i37, i33, i28, i8] -> [i37, i28, i33 + i8, remH i37 4 + i28])), sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sscalar 0.0))))))])) (\\[i46, i41, i36, i32, i27, i26, i24, i23] -> [ifH (notB (2 <=. remH i36 4 + i27) &&* (notB (2 <=. i46 + i26) &&* (notB (2 <=. i41 + i24) &&* notB (2 <=. i32 + i23)))) 0 1, i41, i36, i32, i27, i24, i23]) in stranspose @[4,5,6,7,0,1,2,3] w30 !$ [0, 0, 0, 0])"
-}

testCNNOPP3b :: Assertion
testCNNOPP3b = do
  resetVarCounter
  let artifactRev = revArtifactAdapt UseIncomingCotangent (maxPool2dUnpadded33 . conv2dUnpadded3) (FTKR [2, 2, 2, 2] (FTKScalar @Double))
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (stranspose @[2,1,0] (sgather (sgather (sappend (sreplicate @1 (stranspose @[0,4,5,1,2,3] (sgather (sfromVector (fromList [stranspose @[5,2,3,4,0,1] (sreplicate @2 (sreplicate @2 (stranspose @[2,3,1,0] (sgather (stranspose @[3,0,2,1] (sgather (stranspose @[3,0,1,2] (sgather (stranspose @[3,0,2,1] (sfromR u1) !$ [1]) (\\[i191, i193] -> [remH i191 4 + i193]))) (\\[i195, i197] -> [i195 + i197, i195]))) (\\[i128, i129] -> [i128 + i129, i128]))))), sconcrete (sreplicate [2,2,2,2,2,2] 0.0)])) (\\[i130, i131, i132, i133] -> [ifH (notB (2 <=. remH i130 4 + i131) &&* (notB (2 <=. i130 + i132) &&* notB (2 <=. i130 + i133))) 0 1, i130, i131, i132, i133])))) (sconcrete (sreplicate [1,2,2,2,2,2,2] 0.0))) (\\[i134, i135, i136, i137] -> [i135, i134, i135, i137, i135, i137, i134])) (\\[i138] -> [remH i138 4])))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> rfromS (stranspose @[2,1,0] (sgather (sgather (sappend (sreplicate @1 (stranspose @[0,4,5,1,2,3] (sgather (sfromVector (fromList [stranspose @[5,2,3,4,0,1] (sreplicate @2 (sreplicate @2 (stranspose @[2,3,1,0] (sgather (stranspose @[3,0,2,1] (sgather (stranspose @[3,0,1,2] (sgather (stranspose @[3,0,2,1] (sfromR u1) !$ [1]) (\\[i124, i125] -> [remH i124 4 + i125]))) (\\[i126, i127] -> [i126 + i127, i126]))) (\\[i128, i129] -> [i128 + i129, i128]))))), sconcrete (sreplicate [2,2,2,2,2,2] 0.0)])) (\\[i130, i131, i132, i133] -> [ifH (notB (2 <=. remH i130 4 + i131) &&* (notB (2 <=. i130 + i132) &&* notB (2 <=. i130 + i133))) 0 1, i130, i131, i132, i133])))) (sconcrete (sreplicate [1,2,2,2,2,2,2] 0.0))) (\\[i134, i135, i136, i137] -> [i135, i134, i135, i137, i135, i137, i134])) (\\[i138] -> [remH i138 4])))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> let w145 = sscatter (sscatter (stranspose @[2,1,0] (sfromR dret)) (\\[i140] -> [remH i140 4])) (\\[i141, i142, i143, i144] -> [i142, i141, i142, i144, i142, i144, i141]) ; w150 = sscatter (stranspose @[0,3,4,5,1,2] (ssum @1 (sslice (SNat @0) (SNat @1) w145))) (\\[i146, i147, i148, i149] -> [ifH (notB (2 <=. remH i146 4 + i147) &&* (notB (2 <=. i146 + i148) &&* notB (2 <=. i146 + i149))) 0 1, i146, i147, i148, i149]) in rfromS (stranspose @[1,3,2,0] (soneHot (sscatter (stranspose @[1,2,3,0] (sscatter (stranspose @[1,3,2,0] (sscatter (stranspose @[3,2,0,1] (ssum @2 (ssum @2 (stranspose @[4,5,1,2,3,0] (w150 !$ [0]))))) (\\[i151, i152] -> [i151 + i152, i151]))) (\\[i153, i154] -> [i153 + i154, i153]))) (\\[i155, i156] -> [remH i155 4 + i156])) [1]))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> rfromS (stranspose @[1,3,2,0] (soneHot (sscatter (stranspose @[1,2,3,0] (sscatter (stranspose @[1,3,2,0] (sscatter (ssum @2 (ssum @2 (stranspose @[0,5,6,1,4,2,3] (sscatter (sscatter (sscatter (stranspose @[2,1,0] (sfromR dret)) (\\[i140] -> [remH i140 4])) (\\[i141, i142, i143, i144] -> [i142, i141, i142, i144, i141, i142, i144]) !$ [0]) (\\[i146, i147, i148, i149] -> [ifH (notB (2 <=. remH i146 4 + i147) &&* (notB (2 <=. i146 + i148) &&* notB (2 <=. i146 + i149))) 0 1, i146, i147, i148, i149])) !$ [0]))) (\\[i151, i152] -> [i151 + i152, i151]))) (\\[i153, i154] -> [i153 + i154, i153]))) (\\[i155, i156] -> [remH i155 4 + i156])) [1]))"

maxPool2dUnpadded3
  :: (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r)
maxPool2dUnpadded3 arr =
  rbuild [2, 2, 2, 2] $ \case
    [aa, bb, iBh, iBw] ->
      let arrt = slicez3 [2, 2, 2, 2] arr [iBh `quotH` 4, aa, bb, iBw]
      in rmaximum3 arrt
    _ -> error "maxPool2dUnpadded3: impossible pattern needlessly required"

maxPool2dUnpadded33
  :: (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r)
maxPool2dUnpadded33 arr =
  rbuild [2, 2, 2, 2] $ \case
    [aa, bb, iBh, iBw] ->
      let arrt = slicez33 [2, 2, 2, 2] arr [iBh `remH` 4, aa, bb, iBw]
      in rmaximum3 arrt
    _ -> error "maxPool2dUnpadded33: impossible pattern needlessly required"

conv2dUnpadded3
  :: (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dUnpadded3 arrA =
  let shB = [2, 2, 2, 2]
  in rbuild shB $ \case
    [iImg, _, iBh, iBw] ->
      let arrAt = slicez33 shB arrA [iImg `remH` 4, iImg, iImg, 1]
      in rindex0 arrAt [iBh, iBw, iImg, iBh]
    _ -> error "conv2dUnpadded3: impossible pattern needlessly required"

slicez3
  :: (ADReady target, GoodScalar r, KnownNat n)
  => IShR n -> target (TKR n r) -> IxROf target n -> target (TKR n r)
slicez3 shOut d ixBase =
  rbuild shOut $ \_ -> indexz03 d (ixrZipWith (+) ixBase ixBase)

slicez33
  :: (ADReady target, GoodScalar r, KnownNat n)
  => IShR n -> target (TKR n r) -> IxROf target n -> target (TKR n r)
slicez33 shOut d ixBase =
  rbuild shOut $ \ixResult -> indexz03 d (ixrZipWith (+) ixBase ixResult)

indexz03
  :: forall target r n. (ADReady target, GoodScalar r, KnownNat n)
  => target (TKR n r) -> IxROf target n -> target (TKR 0 r)
indexz03 d ix = ifH (within0 @target (rshape @target d) ix) (d ! ix) (rscalar 0)

rmaximum3 :: (BaseTensor target, LetTensor target, KnownNat n, GoodScalar r)
         => target (TKR n r) -> target (TKR 0 r)
rmaximum3 t0 = tlet t0 $ \t -> rindex0 t [0, 0, 0, 0]

testCNNOPP4 :: Assertion
testCNNOPP4 = do
  resetVarCounter
  let blackGlyph :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
      blackGlyph = AstFromPrimal $ AstReplicate (SNat @3) knownSTK
                   $ AstReplicate (SNat @3) knownSTK
                   $ AstReplicate (SNat @3) knownSTK
                   $ AstReplicate (SNat @3) knownSTK
                       (rconcrete $ Nested.rscalar 7
                        :: AstTensor AstMethodLet PrimalSpan (TKR 0 Double))
      afcnn2T :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
      afcnn2T = maxPool2dUnpadded4 blackGlyph
  printAstPretty (simplifyInlineContract afcnn2T)
    @?= "rfromS (str (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sreplicate @1 (stranspose @[2,1,0] (sgather (stranspose @[3,4,5,2,6,1,0] (sgather (stranspose @[4,1,3,0,2] (sgather (stranspose @[3,0,4,1,2] (sgather (sconcrete (sreplicate [2,3,3,3] 7.0)) (\\[i61, i64] -> [i61 + i64]))) (\\[i67, i69] -> [3 + (negate i69 + i67), i69]))) (\\[i71, i73, i76] -> [i71 * i73 + i76])) !$ [1, 0, 0, 0]) (\\[i84] -> [2 * i84]))))))"
      -- TODO: was once "rfromS (sconcrete (sfromListLinear [2,2,2,2] [0.0,0.0,0.0,0.0,7.0,7.0,7.0,7.0,0.0,0.0,0.0,0.0,7.0,7.0,7.0,7.0]))"
  printAstPretty afcnn2T
    @?= "rfromS (let w19 = sgather (sfromVector (fromList [stranspose @[3,0,5,6,1,2,4] (sgather (stranspose @[6,0,3,1,4,5,2] (sgather (stranspose @[3,0,2,1] (sgather (stranspose @[0,2,1] (sgather (sconcrete (sreplicate [2,3,3,3] 7.0)) (\\[i32, i5] -> [i32 + i5]))) (\\[i31, i6] -> [i31, 3 + (negate i31 + i6)]))) (\\[i36, i26, i7] -> [i36 * i26 + i7]))) (\\[i22, i8] -> [2 * i22 + i8])), sconcrete (sreplicate [2,2,2,2,2,2,2,2] 0.0)])) (\\[i28, i21, i15, i12, i9] -> [ifH (notB (2 <=. i28 + i15) &&* (notB (0 <=. negate i28 + i12) &&* notB (3 <=. 2 * i21 + i9))) 0 1, i28, i21, i15, i12, i9]) in stranspose @[2,3,4,7,5,0,6,1] w19 !$ [0, 0, 0, 0])"

testCNNOPP4b :: Assertion
testCNNOPP4b = do
  resetVarCounter
  let artifactRev = revArtifactAdapt UseIncomingCotangent maxPool2dUnpadded4 (FTKR [3, 3, 3, 3] (FTKScalar @Double))
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (str (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sreplicate @1 (stranspose @[2,1,0] (sgather (stranspose @[2,3,0,1] (sgather (stranspose @[1,0,3,2] (sreplicate @2 (stranspose @[2,3,0,1] (sreplicate @2 (stranspose @[2,1,0] (sreplicate @2 (sfromR u1 !$ [2, 2]))))))) (\\[i194, i195] -> [i195 * i194, i194, i195]))) (\\[i125] -> [i125, 2 * i125]))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> rfromS (str (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sreplicate @1 (stranspose @[2,1,0] (sgather (stranspose @[3,5,6,2,4,7,1,0] (sgather (stranspose @[3,4,7,1,5,6,0,2] (sgather (stranspose @[6,0,7,4,3,2,1,5] (sgather (sslice (SNat @1) (SNat @2) (stranspose @[0,2,1] (sfromR u1))) (\\[i115, i116, i117, i118, i119] -> [i115 + i116]))) (\\[i120, i121] -> [3 + (negate i121 + i120), i121]))) (\\[i122, i123, i124] -> [i122, i123, i122 * i123 + i124])) !$ [1, 0, 0, 0]) (\\[i125] -> [i125, 2 * i125]))))))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> rfromS (stranspose @[0,2,1] (sappend (sconcrete (sreplicate [1,3,3,3] 0.0)) (sappend (sscatter (stranspose @[1,6,5,4,3,7,0,2] (sscatter (stranspose @[6,3,7,0,1,4,5,2] (sscatter (stranspose @[7,6,3,0,4,1,2,5] (soneHot (sscatter (stranspose @[2,1,0] (ssum @1 (sslice (SNat @1) (SNat @1) (str (sfromR dret))))) (\\[i127] -> [i127, 2 * i127])) [1, 0, 0, 0])) (\\[i128, i129, i130] -> [i128, i129, i128 * i129 + i130]))) (\\[i131, i132] -> [3 + (negate i132 + i131), i132]))) (\\[i133, i134, i135, i136, i137] -> [i133 + i134])) (sconcrete (sfromListLinear [0,3,3,3] [])))))"
      -- TODO: was once "\\dret u1 -> rfromS (soneHot (sscatter (ssum @1 (sslice (SNat @1) (SNat @1) (str (sfromR dret)))) (\\[i86, i87, i88] -> [i86 * i87, 2 * i88])) [2, 2])"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> rfromS (sappend (sconcrete (sreplicate [1,3,3,3] 0.0)) (stranspose @[0,2,1] (sscatter (stranspose @[1,6,5,4,3,7,0,2] (sscatter (stranspose @[6,3,7,0,1,4,5,2] (sscatter (stranspose @[7,6,3,0,4,1,2,5] (soneHot (sscatter (stranspose @[1,3,2,0] (sfromR dret) !$ [1]) (\\[i127] -> [i127, 2 * i127])) [1, 0, 0, 0])) (\\[i128, i129, i130] -> [i128, i129, i128 * i129 + i130]))) (\\[i131, i132] -> [3 + (negate i132 + i131), i132]))) (\\[i133, i134, i135, i136, i137] -> [i133 + i134]))))"
      -- TODO: was once "\\dret u1 -> rfromS (soneHot (sscatter (str (sfromR dret) !$ [1]) (\\[i86, i87, i88] -> [i86 * i87, 2 * i88])) [2, 2])"

testCNNOPP5 :: Assertion
testCNNOPP5 = do
  resetVarCounter
  let blackGlyph :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
      blackGlyph = AstFromPrimal $ AstReplicate (SNat @6) knownSTK
                   $ AstReplicate (SNat @6) knownSTK
                   $ AstReplicate (SNat @6) knownSTK
                   $ AstReplicate (SNat @6) knownSTK
                       (rconcrete $ Nested.rscalar 7
                        :: AstTensor AstMethodLet PrimalSpan (TKR 0 Double))
      afcnn2T :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
      afcnn2T = conv2dUnpadded4 blackGlyph
  printAstPretty (simplifyInlineContract afcnn2T)
    @?= "rfromS (sconcrete (sreplicate [1,1,2,2] 7.0))"
  printAstPretty afcnn2T
    @?= "rfromS (sconcrete (sreplicate [1,1,2,2] 7.0))"

testCNNOPP5b :: Assertion
testCNNOPP5b = do
  resetVarCounter
  let artifactRev = revArtifactAdapt UseIncomingCotangent conv2dUnpadded4 (FTKR [5, 5, 5, 5] (FTKScalar @Double))
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (sreplicate @1 (sreplicate @1 (str (sslice (SNat @0) (SNat @2) (str (sslice (SNat @0) (SNat @2) (sfromR u1 !$ [0, 0])))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> rfromS (sreplicate @1 (sreplicate @1 (str (sslice (SNat @0) (SNat @2) (str (sslice (SNat @0) (SNat @2) (sfromR u1 !$ [0, 0])))))))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> rfromS (soneHot (sappend (sconcrete (sfromListLinear [0,5] [])) (sappend (str (sappend (sconcrete (sfromListLinear [0,2] [])) (sappend (str (ssum @1 (ssum @1 (sfromR dret)))) (sconcrete (sreplicate [3,2] 0.0))))) (sconcrete (sreplicate [3,5] 0.0)))) [0, 0])"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> rfromS (soneHot (sappend (str (sappend (stranspose @[0,1,3,2] (sfromR dret) !$ [0, 0]) (sconcrete (sreplicate [3,2] 0.0)))) (sconcrete (sreplicate [3,5] 0.0))) [0, 0])"

maxPool2dUnpadded4
  :: (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r)
maxPool2dUnpadded4 arr =
  rbuild [2, 2, 2, 2] $ \case
    [aa, bb, iBh, iBw] ->
      let arrt = slicez4 [2, 2, 2, 2] arr [bb + 1, 3 - bb, aa * iBh, 2 * iBw]
      in rmaximum3 arrt
    _ -> error "maxPool2dUnpadded4: impossible pattern needlessly required"

conv2dUnpadded4
  :: (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dUnpadded4 arrA =
  let shB = [1, 1, 2, 2]
  in rbuild shB $ \case
    [iImg, _, iBh, iBw] ->
      let arrAt = slicez4 shB arrA [iImg, 0, iBh, iBw]
      in rindex0 arrAt [0, 0, 0, 0]
    _ -> error "conv2dUnpadded4: impossible pattern needlessly required"

slicez4
  :: (ADReady target, GoodScalar r, KnownNat n)
  => IShR n -> target (TKR n r) -> IxROf target n -> target (TKR n r)
slicez4 shOut d ixBase =
  rbuild shOut $ \ixResult -> indexz03 d (ixrZipWith (+) ixBase ixResult)

testCNNOPP6 :: Assertion
testCNNOPP6 = do
  resetVarCounter
  let blackGlyph :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
      blackGlyph = AstFromPrimal $ AstReplicate (SNat @2) knownSTK
                   $ AstReplicate (SNat @2) knownSTK
                   $ AstReplicate (SNat @2) knownSTK
                   $ AstReplicate (SNat @2) knownSTK
                       (rconcrete $ Nested.rscalar 7
                        :: AstTensor AstMethodLet PrimalSpan (TKR 0 Double))
      afcnn2T :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
      afcnn2T = maxPool2dUnpadded3 $ conv2dUnpadded3z blackGlyph
  printAstPretty (simplifyInlineContract afcnn2T)
    @?= "rfromS (sconcrete (sfromListLinear [2,2,2,2] [7.0,0.0,7.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]))"
  printAstPretty afcnn2T
    @?= "rfromS (stranspose @[1,2,0] (sreplicate @2 (let t30 = sgather (stranspose @[2,1,0] (sgather (str (sgather (sreplicate @2 (str (sreplicate @2 (let m21 = sgather (str (sgather (sconcrete (sreplicate [2,2,2,2] 7.0)) (\\[i9] -> [2 * i9, 2 * i9, 2 * i9]))) (\\[i12] -> [2 * i12]) in sappend (sreplicate @1 (sappend (sreplicate @1 (m21 !$ [0, 0])) (sconcrete (sfromListLinear [1] [0.0])))) (sconcrete (sreplicate [1,2] 0.0)))))) (\\[i1] -> [2 * i1, 0]))) (\\[i2] -> [2 * i2]))) (\\[i4] -> [2 * i4]) in sappend (sreplicate @1 (sappend (sreplicate @1 (sappend (sreplicate @1 (t30 !$ [0, 0, 0])) (sconcrete (sfromListLinear [1] [0.0])))) (sconcrete (sreplicate [1,2] 0.0)))) (sconcrete (sreplicate [1,2,2] 0.0)))))"


testCNNOPP6b :: Assertion
testCNNOPP6b = do
  resetVarCounter
  let artifactRev = revArtifactAdapt UseIncomingCotangent (maxPool2dUnpadded3 . conv2dUnpadded3z) (FTKR [2, 2, 2, 2] (FTKScalar @Double))
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (stranspose @[1,2,0] (sreplicate @2 (sappend (sreplicate @1 (sappend (sreplicate @1 (sappend (sreplicate @1 (sfromR u1 !$ [0, 0, 0, 0])) (sconcrete (sfromListLinear [1] [0.0])))) (sconcrete (sreplicate [1,2] 0.0)))) (sconcrete (sreplicate [1,2,2] 0.0)))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> rfromS (stranspose @[1,2,0] (sreplicate @2 (sappend (sreplicate @1 (sappend (sreplicate @1 (sappend (sreplicate @1 (sfromR u1 !$ [0, 0, 0, 0])) (sconcrete (sfromListLinear [1] [0.0])))) (sconcrete (sreplicate [1,2] 0.0)))) (sconcrete (sreplicate [1,2,2] 0.0)))))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> let t34 = ssum @2 (stranspose @[2,0,1] (sfromR dret)) in rfromS (soneHot (ssum @1 (sslice (SNat @0) (SNat @1) (ssum @1 (sslice (SNat @0) (SNat @1) (ssum @1 (sslice (SNat @0) (SNat @1) t34)))))) [0, 0, 0, 0])"


  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> rfromS (soneHot (ssum0 (stranspose @[0,1,3,2] (sfromR dret) !$ [0, 0, 0])) [0, 0, 0, 0])"

conv2dUnpadded3z
  :: (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dUnpadded3z arrA =
  let shB = [2, 2, 2, 2]
  in rbuild shB $ \case
    [iImg, _, iBh, iBw] ->
      let arrAt = slicez3 shB arrA [iImg, iImg, iImg, iBw]
      in rindex0 arrAt [iBh, iBw, iImg, iBh]
    _ -> error "conv2dUnpadded3z: impossible pattern needlessly required"

testCNNOPP7 :: Assertion
testCNNOPP7 = do
  resetVarCounter
  let blackGlyph :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
      blackGlyph = AstFromPrimal $ AstReplicate (SNat @2) knownSTK
                   $ AstReplicate (SNat @2) knownSTK
                   $ AstReplicate (SNat @2) knownSTK
                   $ AstReplicate (SNat @2) knownSTK
                       (rconcrete $ Nested.rscalar 7
                        :: AstTensor AstMethodLet PrimalSpan (TKR 0 Double))
      afcnn2T :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
      afcnn2T = maxPool2dUnpadded3y $ conv2dUnpadded3y blackGlyph
  printAstPretty (simplifyInlineContract afcnn2T)
    @?= "rfromS (sconcrete (sfromListLinear [2,2,2,2] [7.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]))"
  printAstPretty afcnn2T
    @?= "rfromS (let u27 = sgather (stranspose @[3,2,0,1] (sgather (stranspose @[1,2,0] (sgather (sreplicate @2 (stranspose @[1,2,0] (sreplicate @2 (let m21 = sgather (str (sgather (sconcrete (sreplicate [2,2,2,2] 7.0)) (\\[i9] -> [2 * i9, 2 * i9, 2 * i9]))) (\\[i11] -> [2 * i11]) in sappend (sreplicate @1 (sappend (sreplicate @1 (m21 !$ [0, 0])) (sconcrete (sfromListLinear [1] [0.0])))) (sconcrete (sreplicate [1,2] 0.0)))))) (\\[i1] -> [2 * i1]))) (\\[i31, i3] -> [2 * i3, 2 * i31]))) (\\[i4] -> [2 * i4]) in stranspose @[1,2,0] (sappend (sreplicate @1 (sappend (sreplicate @1 (sappend (sreplicate @1 (sappend (sreplicate @1 (u27 !$ [0, 0, 0, 0])) (sconcrete (sfromListLinear [1] [0.0])))) (sconcrete (sreplicate [1,2] 0.0)))) (sconcrete (sreplicate [1,2,2] 0.0)))) (sconcrete (sreplicate [1,2,2,2] 0.0))))"

testCNNOPP7b :: Assertion
testCNNOPP7b = do
  resetVarCounter
  let artifactRev = revArtifactAdapt UseIncomingCotangent (maxPool2dUnpadded3y . conv2dUnpadded3y) (FTKR [2, 2, 2, 2] (FTKScalar @Double))
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (stranspose @[1,2,0] (sappend (sreplicate @1 (sappend (sreplicate @1 (sappend (sreplicate @1 (sappend (sreplicate @1 (sfromR u1 !$ [0, 0, 0, 0])) (sconcrete (sfromListLinear [1] [0.0])))) (sconcrete (sreplicate [1,2] 0.0)))) (sconcrete (sreplicate [1,2,2] 0.0)))) (sconcrete (sreplicate [1,2,2,2] 0.0))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> rfromS (stranspose @[1,2,0] (sappend (sreplicate @1 (sappend (sreplicate @1 (sappend (sreplicate @1 (sappend (sreplicate @1 (sfromR u1 !$ [0, 0, 0, 0])) (sconcrete (sfromListLinear [1] [0.0])))) (sconcrete (sreplicate [1,2] 0.0)))) (sconcrete (sreplicate [1,2,2] 0.0)))) (sconcrete (sreplicate [1,2,2,2] 0.0))))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> rfromS (soneHot (ssum @1 (sslice (SNat @0) (SNat @1) (ssum @1 (sslice (SNat @0) (SNat @1) (ssum @1 (sslice (SNat @0) (SNat @1) (ssum @1 (sslice (SNat @0) (SNat @1) (stranspose @[2,0,1] (sfromR dret)))))))))) [0, 0, 0, 0])"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> rfromS (soneHot (sfromR dret !$ [0, 0, 0, 0]) [0, 0, 0, 0])"

maxPool2dUnpadded3y
  :: (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r)
maxPool2dUnpadded3y arr =
  rbuild [2, 2, 2, 2] $ \case
    [aa, bb, iBh, iBw] ->
      let arrt = slicez3 [2, 2, 2, 2] arr [iBh, aa, bb, iBw]
      in rmaximum3 arrt
    _ -> error "maxPool2dUnpadded3y: impossible pattern needlessly required"

conv2dUnpadded3y
  :: (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dUnpadded3y arrA =
  let shB = [2, 2, 2, 2]
  in rbuild shB $ \case
    [iImg, _, iBh, iBw] ->
      let arrAt = slicez3 shB arrA [iImg, iImg, iImg, iBh]
      in rindex0 arrAt [iBh, iBw, iImg, iBh]
    _ -> error "conv2dUnpadded3y: impossible pattern needlessly required"

-- TODO: OOMs
_testPaddedCNNOPP0c :: Assertion
_testPaddedCNNOPP0c = do
  resetVarCounter
  let artifactRev = revArtifactAdapt UseIncomingCotangent conv2dCPadded (FTKR [2, 2, 2, 2] (FTKScalar @Double))
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[5,0,1,4,2,3] (sgather (sconcrete (sfromListLinear [4,4,2,2] [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,5.0,13.1,-2.0,582934.0,2.0,9.0,0.0,2.99432,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,6.0,8.0,0.1,-335.0,1.0,-4.0,-0.2,26.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0])) (\\[i37, i38, i39, i40] -> [i37 + i39, i38 + i40])))))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let w41 = str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[5,0,1,4,2,3] (sgather (sconcrete (sfromListLinear [4,4,2,2] [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,5.0,13.1,-2.0,582934.0,2.0,9.0,0.0,2.99432,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,6.0,8.0,0.1,-335.0,1.0,-4.0,-0.2,26.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0])) (\\[i37, i38, i39, i40] -> [i37 + i39, i38 + i40])))))) in rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (w41 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> let w41 = str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[5,0,1,4,2,3] (sgather (sconcrete (sfromListLinear [4,4,2,2] [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,5.0,13.1,-2.0,582934.0,2.0,9.0,0.0,2.99432,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,6.0,8.0,0.1,-335.0,1.0,-4.0,-0.2,26.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0])) (\\[i37, i38, i39, i40] -> [i37 + i39, i38 + i40])))))) in rfromS (ssum @1 (str (ssum @2 (str (ssum @2 (str (ssum @2 (w41 * sreshape @[2,2,2,2,1,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret)))))))))))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> rfromS (ssum @2 (ssum @2 (sdot1In (stranspose @[0,1,2,6,3,4,5] (sgather (sconcrete (sfromListLinear [4,4,2,2] [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,5.0,-2.0,13.1,582934.0,2.0,0.0,9.0,2.99432,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,6.0,0.1,8.0,-335.0,1.0,-0.2,-4.0,26.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0])) (\\[i66, i67, i68, i73, i74] -> [i66 + i73, i67 + i74]))) (stranspose @[4,2,3,1,5,6,7,0] (sreshape @[2,2,2,2,1,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret)))) !$ [0]))))"

-- TODO: OOMs
_testPaddedCNNOPP0b :: Assertion
_testPaddedCNNOPP0b = do
  resetVarCounter
  let artifactRev = revArtifactAdapt UseIncomingCotangent conv2dBPadded (FTKR [2, 2, 2, 2] (FTKScalar @Double))
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (sconcrete (sfromListLinear [2,2,2,2,1,2,2,2] [5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0]) * str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (sgather (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,4] 0.0)) (stranspose @[9,2,5,1,3,4,6,7,8,0] (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,3] 0.0)) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (sslice (SNat @1) (SNat @2) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sslice (SNat @1) (SNat @2) (stranspose @[9,1,2,3,4,5,0,7,8,6] (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR u1)) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0)))))))))))))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,2] 0.0)))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,3] 0.0))))))) (\\[i50, i51, i52, i53, i54, i55] -> [i51 + i54, i51, i54, i50, i52, i53, i55, i50, i53, i52 + i55])))))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let w56 = str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (sgather (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,4] 0.0)) (stranspose @[9,2,5,1,3,4,6,7,8,0] (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,3] 0.0)) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (sslice (SNat @1) (SNat @2) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sslice (SNat @1) (SNat @2) (stranspose @[9,1,2,3,4,5,0,7,8,6] (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR u1)) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0)))))))))))))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,2] 0.0)))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,3] 0.0))))))) (\\[i50, i51, i52, i53, i54, i55] -> [i51 + i54, i51, i54, i50, i52, i53, i55, i50, i53, i52 + i55]))))) in rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (sconcrete (sfromListLinear [2,2,2,2,1,2,2,2] [5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0]) * w56))))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> let w64 = sscatter (ssum @1 (stranspose @[3,0,1,2] (ssum @2 (str (sconcrete (sfromListLinear [2,2,2,2,1,2,2,2] [5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0]) * sreshape @[2,2,2,2,1,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret)))))))) (\\[i58, i59, i60, i61, i62, i63] -> [i59 + i62, i59, i62, i58, i60, i61, i63, i58, i61, i60 + i63]) ; u65 = ssum @2 (ssum @2 (ssum @2 (ssum @2 (ssum @2 (ssum @2 (stranspose @[6,1,2,3,4,5,9,7,8,0] (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,4] 0.0)) (sappend (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,2] 0.0)) (sappend (sslice (SNat @0) (SNat @2) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sslice (SNat @0) (SNat @2) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sslice (SNat @1) (SNat @3) (stranspose @[9,3,1,4,5,2,6,7,8,0] (sslice (SNat @1) (SNat @3) w64))))))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,4] 0.0)))))))))) in rfromS (stranspose @[1,2,0] (sslice (SNat @0) (SNat @2) (sslice (SNat @1) (SNat @3) (stranspose @[3,1,2,0] (sslice (SNat @0) (SNat @2) (sslice (SNat @1) (SNat @3) u65))))))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> rfromS (stranspose @[1,2,0] (sslice (SNat @1) (SNat @2) (stranspose @[3,1,2,0] (sslice (SNat @1) (SNat @2) (ssum @2 (ssum @2 (ssum @2 (ssum @2 (ssum @2 (ssum @2 (stranspose @[6,1,2,3,4,5,9,7,8,0] (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,4] 0.0)) (sappend (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,2] 0.0)) (sappend (sslice (SNat @0) (SNat @2) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sslice (SNat @0) (SNat @2) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sslice (SNat @1) (SNat @3) (stranspose @[9,3,1,4,5,2,6,7,8,0] (sslice (SNat @1) (SNat @3) (sscatter (sdot1In (sconcrete (sfromListLinear [2,2,2,2,2,2,2] [5.0,13.1,2.0,9.0,6.0,8.0,1.0,-4.0,-2.0,582934.0,0.0,2.99432,0.1,-335.0,-0.2,26.0,5.0,13.1,2.0,9.0,6.0,8.0,1.0,-4.0,-2.0,582934.0,0.0,2.99432,0.1,-335.0,-0.2,26.0,5.0,13.1,2.0,9.0,6.0,8.0,1.0,-4.0,-2.0,582934.0,0.0,2.99432,0.1,-335.0,-0.2,26.0,5.0,13.1,2.0,9.0,6.0,8.0,1.0,-4.0,-2.0,582934.0,0.0,2.99432,0.1,-335.0,-0.2,26.0,5.0,13.1,2.0,9.0,6.0,8.0,1.0,-4.0,-2.0,582934.0,0.0,2.99432,0.1,-335.0,-0.2,26.0,5.0,13.1,2.0,9.0,6.0,8.0,1.0,-4.0,-2.0,582934.0,0.0,2.99432,0.1,-335.0,-0.2,26.0,5.0,13.1,2.0,9.0,6.0,8.0,1.0,-4.0,-2.0,582934.0,0.0,2.99432,0.1,-335.0,-0.2,26.0,5.0,13.1,2.0,9.0,6.0,8.0,1.0,-4.0,-2.0,582934.0,0.0,2.99432,0.1,-335.0,-0.2,26.0])) (stranspose @[4,0,2,3,5,6,7,1] (sreshape @[2,2,2,2,1,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret)))) !$ [0])) (\\[i58, i59, i60, i61, i62, i63] -> [i59 + i62, i59, i62, i58, i60, i61, i63, i58, i61, i60 + i63]))))))))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,4] 0.0)))))))))))))))"

-- TODO: OOMs
_testPaddedCNNOPP1e :: Assertion
_testPaddedCNNOPP1e = do
  resetVarCounter
  let f :: AstTensor AstMethodLet FullSpan
                     (TKProduct (TKR 4 Double) (TKR 4 Double))
        -> AstTensor AstMethodLet FullSpan
                     (TKR 4 Double)
      f v = conv2dPadded (tproject1 v) (tproject2 v)
      ftk = FTKProduct (FTKR (2 :$: 2 :$: 2 :$: 2 :$: ZSR) FTKScalar)
                       (FTKR (2 :$: 2 :$: 2 :$: 2 :$: ZSR) FTKScalar)
      (artifactRev, _) =
        revArtifactFromForwardPass
          UseIncomingCotangent (forwardPassByInterpretation f emptyEnv) ftk
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (str (sreplicate @2 (sgather (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,4] 0.0)) (stranspose @[9,2,5,1,3,4,6,7,8,0] (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,3] 0.0)) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (sslice (SNat @1) (SNat @2) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sslice (SNat @1) (SNat @2) (stranspose @[9,1,2,3,4,5,0,7,8,6] (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0)))))))))))))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,2] 0.0)))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,3] 0.0))))))) (\\[i44, i75, i45, i46, i76, i47] -> [i75 + i76, i75, i76, i44, i45, i46, i47, i44, i46, i45 + i47]))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (sfromR (tproject1 u1))))))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let u41 = sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0))) ; w48 = str (sreplicate @2 (stranspose @[0,4,1,2,5,3] (sgather (stranspose @[2,3,4,5,6,7,8,0,1] (sgather (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,4] 0.0)) (stranspose @[9,2,5,1,3,4,6,7,8,0] (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,3] 0.0)) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (sslice (SNat @1) (SNat @2) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sslice (SNat @1) (SNat @2) (stranspose @[9,1,2,3,4,5,0,7,8,6] (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 u41)))))))))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,2] 0.0)))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,3] 0.0))))))) (\\[i42, i43] -> [i42 + i43, i42, i43]))) (\\[i44, i45, i46, i47] -> [i44, i45, i46, i47, i44, i46, i45 + i47])))) in rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (w48 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (sfromR (tproject1 u1))))))))))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> tconvert (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2,2,2] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2,2,2] FTKScalar)) ConvSX))) (STKProduct (STKS [2,2,2,2] STKScalar) (STKS [2,2,2,2] STKScalar)) (let w50 = sreshape @[2,2,2,2,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret))) in tpair (ssum @2 (ssum @2 (sdot1In (stranspose @[2,3,0,4,5,6,1] (sreplicate @2 (sgather (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,4] 0.0)) (stranspose @[9,2,5,1,3,4,6,7,8,0] (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,3] 0.0)) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (sslice (SNat @1) (SNat @2) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sslice (SNat @1) (SNat @2) (stranspose @[9,1,2,3,4,5,0,7,8,6] (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0)))))))))))))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,2] 0.0)))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,3] 0.0))))))) (\\[i44, i84, i45, i46, i85, i47] -> [i84 + i85, i84, i85, i44, i45, i46, i47, i44, i46, i45 + i47])))) (stranspose @[2,3,1,4,5,6,0] w50)))) (stranspose @[1,2,0] (sslice (SNat @1) (SNat @2) (stranspose @[3,1,2,0] (sslice (SNat @1) (SNat @2) (ssum @2 (ssum @2 (ssum @2 (ssum @2 (ssum @2 (ssum @2 (stranspose @[6,1,2,3,4,5,9,7,8,0] (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,4] 0.0)) (sappend (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,2] 0.0)) (sappend (sslice (SNat @0) (SNat @2) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sslice (SNat @0) (SNat @2) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sslice (SNat @1) (SNat @3) (stranspose @[9,3,1,4,5,2,6,7,8,0] (sslice (SNat @1) (SNat @3) (sscatter (stranspose @[7,8,0,1,2,3,4,5,6] (sscatter (sdot1In (sreplicate @2 (stranspose @[2,3,5,0,4,1] (sreplicate @2 (str (sreplicate @2 (sfromR (tproject1 u1))))))) (stranspose @[0,3,4,6,2,5,1] w50)) (\\[i51, i52, i53, i54] -> [i51, i52, i53, i54, i51, i53, i52 + i54]))) (\\[i55, i56] -> [i55 + i56, i55, i56]))))))))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,4] 0.0))))))))))))))))"

-- This is fragile due to indexing out of bounds, see above.
testPaddedCNNOPP1b :: Assertion
testPaddedCNNOPP1b = do
  resetVarCounter
  let f :: AstTensor AstMethodLet FullSpan
                     (TKProduct (TKR 4 Double) (TKR 4 Double))
        -> AstTensor AstMethodLet FullSpan
                     (TKR 4 Double)
      f v = conv2dShrinking (tproject1 v) (tproject2 v)
      ftk = FTKProduct (FTKR (2 :$: 2 :$: 2 :$: 2 :$: ZSR) FTKScalar)
                       (FTKR (6 :$: 2 :$: 6 :$: 6 :$: ZSR) FTKScalar)
      (artifactRev, _) =
        revArtifactFromForwardPass
          UseIncomingCotangent (forwardPassByInterpretation f emptyEnv) ftk
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[6,2,4,4,8] (str (sreplicate @2 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR (tproject2 u1))) (\\[i64, i66] -> [i64 + i66]))) (\\[i24, i25] -> [i24 + i25])))) * sreplicate @6 (str (sreplicate @4 (str (sreplicate @4 (sfromR (tproject1 u1))))))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let w26 = str (sreplicate @2 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR (tproject2 u1))) (\\[i22, i23] -> [i22 + i23]))) (\\[i24, i25] -> [i24 + i25])))) in rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[6,2,4,4,8] (w26 * sreplicate @6 (str (sreplicate @4 (str (sreplicate @4 (sfromR (tproject1 u1))))))))))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> let w26 = str (sreplicate @2 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR (tproject2 u1))) (\\[i22, i23] -> [i22 + i23]))) (\\[i24, i25] -> [i24 + i25])))) ; w28 = sreshape @[6,2,4,4,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret))) in tpair (rfromS (ssum @4 (str (ssum @4 (str (ssum @6 (w26 * w28))))))) (rfromS (stranspose @[1,2,0] (sscatter (stranspose @[2,4,1,3,0] (sscatter (stranspose @[2,5,0,1,3,4] (ssum @2 (str (sreplicate @6 (str (sreplicate @4 (str (sreplicate @4 (sfromR (tproject1 u1)))))) * w28)))) (\\[i29, i30] -> [i29 + i30]))) (\\[i31, i32] -> [i31 + i32]))))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> tconvert (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2,2,2] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [6,2,6,6] FTKScalar)) ConvSX))) (STKProduct (STKS [2,2,2,2] STKScalar) (STKS [6,2,6,6] STKScalar)) (let w28 = sreshape @[6,2,4,4,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret))) in tpair (ssum @4 (ssum @4 (sdot1In (stranspose @[2,3,0,4,5,6,1] (sreplicate @2 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR (tproject2 u1))) (\\[i81, i83] -> [i81 + i83]))) (\\[i24, i25] -> [i24 + i25]))))) (stranspose @[2,3,1,4,5,6,0] w28)))) (stranspose @[1,2,0] (sscatter (stranspose @[2,4,1,3,0] (sscatter (sdot1In (stranspose @[3,6,0,2,4,5,1] (sreplicate @6 (str (sreplicate @4 (str (sreplicate @4 (sfromR (tproject1 u1)))))))) (stranspose @[3,6,0,2,4,5,1] w28)) (\\[i29, i30] -> [i29 + i30]))) (\\[i31, i32] -> [i31 + i32]))))"

testPaddedCNNOPPLet :: Assertion
testPaddedCNNOPPLet = do
  resetVarCounter
  let f :: AstTensor AstMethodLet FullSpan
                     (TKProduct (TKR 4 Double) (TKR 4 Double))
        -> AstTensor AstMethodLet FullSpan
                     (TKR 4 Double)
      f v = conv2dPaddedLet (tproject1 v) (tproject2 v)
      ftk = FTKProduct (FTKR (2 :$: 2 :$: 2 :$: 2 :$: ZSR) FTKScalar)
                       (FTKR (2 :$: 2 :$: 2 :$: 2 :$: ZSR) FTKScalar)
      (artifactRev, _) =
        revArtifactFromForwardPass
          UseIncomingCotangent (forwardPassByInterpretation f emptyEnv) ftk
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (str (sreplicate @2 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,3] 0.0)) (stranspose @[0,2,3,1] (sgather (sslice (SNat @1) (SNat @3) (stranspose @[3,0,4,1,2] (sslice (SNat @1) (SNat @3) (stranspose @[3,1,2,4,0] (sfromVector (fromList [stranspose @[1,2,3,0] (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0)))), sconcrete (sreplicate [2,2,4,4] 0.0)])))))) (\\[i36, i37] -> [i36, i37, ifH (notB (2 <=. i37) &&* notB (2 <=. i36)) 0 1])))))) (\\[i81, i83] -> [i81 + i83]))) (\\[i41, i42] -> [i41 + i42])))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (sfromR (tproject1 u1))))))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let u35 = sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0))) ; u38 = stranspose @[1,2,0] (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,3] 0.0)) (stranspose @[0,2,3,1] (sgather (sslice (SNat @1) (SNat @3) (stranspose @[3,0,4,1,2] (sslice (SNat @1) (SNat @3) (stranspose @[3,1,2,4,0] (sfromVector (fromList [stranspose @[1,2,3,0] u35, sconcrete (sreplicate [2,2,4,4] 0.0)])))))) (\\[i36, i37] -> [i36, i37, ifH (notB (2 <=. i37) &&* notB (2 <=. i36)) 0 1])))))) ; w43 = str (sreplicate @2 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] u38) (\\[i39, i40] -> [i39 + i40]))) (\\[i41, i42] -> [i41 + i42])))) in rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (w43 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (sfromR (tproject1 u1))))))))))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> tconvert (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2,2,2] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2,2,2] FTKScalar)) ConvSX))) (STKProduct (STKS [2,2,2,2] STKScalar) (STKS [2,2,2,2] STKScalar)) (let w45 = sreshape @[2,2,2,2,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret))) in tpair (ssum @2 (ssum @2 (sdot1In (stranspose @[2,3,0,4,5,6,1] (sreplicate @2 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,3] 0.0)) (stranspose @[0,2,3,1] (sgather (sslice (SNat @1) (SNat @3) (stranspose @[3,0,4,1,2] (sslice (SNat @1) (SNat @3) (stranspose @[3,1,2,4,0] (sfromVector (fromList [stranspose @[1,2,3,0] (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0)))), sconcrete (sreplicate [2,2,4,4] 0.0)])))))) (\\[i36, i37] -> [i36, i37, ifH (notB (2 <=. i37) &&* notB (2 <=. i36)) 0 1])))))) (\\[i100, i102] -> [i100 + i102]))) (\\[i41, i42] -> [i41 + i42]))))) (stranspose @[2,3,1,4,5,6,0] w45)))) (stranspose @[3,1,2,0] (sslice (SNat @1) (SNat @2) (stranspose @[3,2,1,4,0] (sslice (SNat @1) (SNat @2) (stranspose @[3,2,1,0] (sappend (sconcrete (sreplicate [1,2,2,4,2] 0.0)) (stranspose @[1,3,4,0,2] (sappend (sconcrete (sreplicate [1,3,2,2,2] 0.0)) (sscatter (sslice (SNat @1) (SNat @3) (stranspose @[3,0,1,2] (sslice (SNat @1) (SNat @3) (sscatter (stranspose @[2,4,1,3,0] (sscatter (sdot1In (stranspose @[3,6,0,2,4,5,1] (sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (sfromR (tproject1 u1)))))))) (stranspose @[3,6,0,2,4,5,1] w45)) (\\[i46, i47] -> [i46 + i47]))) (\\[i48, i49] -> [i48 + i49]))))) (\\[i51, i52] -> [i51, i52, ifH (notB (2 <=. i52) &&* notB (2 <=. i51)) 0 1]))))))))) !$ [0]))"

conv2dPaddedLet
  :: forall target r. (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r) -> target (TKR 4 r)
conv2dPaddedLet arrK arrA =
  let [nImgs, nCinpA, nAh, nAw] = rshape arrA
      [nCoutK, nCinpK, nKh, nKw] = rshape arrK
      shAPadded = [nImgs, nCinpA, nAh + nKh, nAw + nKw]
      arrAPadded = rbuild @4 @0 @(TKScalar r) @target shAPadded $ \case
        [iImg, iCinp, iPh, iPw] ->
          ifH (iPh <. fromIntegral (nKh `div` 2)
               ||* iPw <. fromIntegral (nKw `div` 2)
               ||* iPh >=. fromIntegral (nAh + nKh `div` 2)
               ||* iPw >=. fromIntegral (nAw + nKw `div` 2))
              (rscalar 0)
              (arrA ! [ iImg
                      , iCinp
                      , iPh - fromIntegral (nKh `div` 2)
                      , iPw - fromIntegral (nKw `div` 2) ])
      nCinp = assert (nCinpA == nCinpK `blame` (nCinpA, nCinpK)) nCinpA
      shB = [nImgs, nCoutK, nAh, nAw]
      shK1 = [1, nCinp, nKh, nKw]
  in tlet arrAPadded $ \arrAPadded2 -> rbuild shB $ \case
    [iImg, iCout, iBh, iBw] ->
      let arrAt = slicezL shK1 arrAPadded2 [iImg, 0, iBh, iBw]
          arrKt = slicezL shK1 arrK [iCout, 0, 0, 0]
      in rdot0 arrAt arrKt
    _ -> error "conv2dPaddedLet: impossible pattern needlessly required"

testPaddedCNNOPPLet2 :: Assertion
testPaddedCNNOPPLet2 = do
  resetVarCounter
  let f :: AstTensor AstMethodLet FullSpan
                     (TKProduct (TKR 4 Double) (TKR 4 Double))
        -> AstTensor AstMethodLet FullSpan
                     (TKR 4 Double)
      f v = conv2dPaddedLet2 (tproject1 v) (tproject2 v)
      ftk = FTKProduct (FTKR (2 :$: 2 :$: 2 :$: 2 :$: ZSR) FTKScalar)
                       (FTKR (2 :$: 2 :$: 2 :$: 2 :$: ZSR) FTKScalar)
      (artifactRev, _) =
        revArtifactFromForwardPass
          UseIncomingCotangent (forwardPassByInterpretation f emptyEnv) ftk
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[3,5,2,0,4,1] (sgather (stranspose @[1,2,0] (sreplicate @2 (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,3] 0.0)) (stranspose @[0,2,3,1] (sgather (sslice (SNat @1) (SNat @3) (stranspose @[3,0,4,1,2] (sslice (SNat @1) (SNat @3) (stranspose @[3,1,2,4,0] (sfromVector (fromList [stranspose @[1,2,3,0] (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0)))), sconcrete (sreplicate [2,2,4,4] 0.0)])))))) (\\[i53, i54] -> [i53, i54, ifH (notB (2 <=. i54) &&* notB (2 <=. i53)) 0 1])))))))) (\\[i180, i182] -> [i180 + i182]))) (\\[i62, i63] -> [i62, i62 + i63])))))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR (tproject1 u1))))))))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let u52 = sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0))) ; u55 = stranspose @[1,2,0] (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,3] 0.0)) (stranspose @[0,2,3,1] (sgather (sslice (SNat @1) (SNat @3) (stranspose @[3,0,4,1,2] (sslice (SNat @1) (SNat @3) (stranspose @[3,1,2,4,0] (sfromVector (fromList [stranspose @[1,2,3,0] u52, sconcrete (sreplicate [2,2,4,4] 0.0)])))))) (\\[i53, i54] -> [i53, i54, ifH (notB (2 <=. i54) &&* notB (2 <=. i53)) 0 1])))))) ; w64 = str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[3,5,2,0,4,1] (sgather (stranspose @[1,2,0] (sgather (stranspose @[1,5,0,2,3,4] (sreplicate @2 (stranspose @[3,0,1,2] (sgather (stranspose @[2,3,0,1] (sreplicate @2 u55)) (\\[i56, i57, i58] -> [i58, i56]))))) (\\[i59] -> [i59, i59]))) (\\[i60, i61] -> [i60, i60 + i61]))) (\\[i62, i63] -> [i62, i62 + i63])))))) in rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (w64 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR (tproject1 u1))))))))))))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> tconvert (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2,2,2] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2,2,2] FTKScalar)) ConvSX))) (STKProduct (STKS [2,2,2,2] STKScalar) (STKS [2,2,2,2] STKScalar)) (let w66 = sreshape @[2,2,2,2,1,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret))) in tpair (ssum @2 (ssum @2 (sdot1In (stranspose @[5,0,1,4,3,2] (sgather (stranspose @[2,4,0,3,1] (sgather (str (sreplicate @2 (sappend (sconcrete (sreplicate [1,2,4,2] 0.0)) (stranspose @[3,2,0,1] (sappend (sconcrete (sreplicate [1,2,2,3] 0.0)) (stranspose @[0,2,3,1] (sgather (sslice (SNat @1) (SNat @3) (stranspose @[3,0,4,1,2] (sslice (SNat @1) (SNat @3) (stranspose @[3,1,2,4,0] (sfromVector (fromList [stranspose @[1,2,3,0] (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0)))), sconcrete (sreplicate [2,2,4,4] 0.0)])))))) (\\[i53, i54] -> [i53, i54, ifH (notB (2 <=. i54) &&* notB (2 <=. i53)) 0 1])))))))) (\\[i253, i255] -> [i255 + i253]))) (\\[i222, i223, i229] -> [i222, i222 + i229]))) (stranspose @[4,2,3,1,5,6,7,0] w66 !$ [0])))) (stranspose @[3,1,2,0] (sslice (SNat @1) (SNat @2) (stranspose @[3,2,1,4,0] (sslice (SNat @1) (SNat @2) (stranspose @[3,2,1,0] (sappend (sconcrete (sreplicate [1,2,2,4,2] 0.0)) (stranspose @[1,3,4,0,2] (sappend (sconcrete (sreplicate [1,3,2,2,2] 0.0)) (sscatter (sslice (SNat @1) (SNat @3) (stranspose @[3,0,1,2] (sslice (SNat @1) (SNat @3) (ssum @2 (stranspose @[2,1,3,0] (sscatter (ssum @2 (stranspose @[2,3,4,5,0,1] (sscatter (stranspose @[2,0,1] (sscatter (stranspose @[3,5,2,0,4,1] (sscatter (sdot1In (sreplicate @2 (stranspose @[2,0,1,4,3] (sreplicate @2 (sreplicate @2 (stranspose @[3,2,1,0] (sfromR (tproject1 u1))))))) (stranspose @[4,3,7,0,2,5,6,1] w66 !$ [0])) (\\[i67, i68] -> [i67, i67 + i68]))) (\\[i69, i70] -> [i69, i69 + i70]))) (\\[i71] -> [i71, i71])))) (\\[i72, i73, i74] -> [i74, i72]))))))) (\\[i76, i77] -> [i76, i77, ifH (notB (2 <=. i77) &&* notB (2 <=. i76)) 0 1]))))))))) !$ [0]))"

conv2dPaddedLet2
  :: forall target r. (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r) -> target (TKR 4 r)
conv2dPaddedLet2 arrK arrA =
  let [nImgs, nCinpA, nAh, nAw] = rshape arrA
      [nCoutK, nCinpK, nKh, nKw] = rshape arrK
      shAPadded = [nImgs, nCinpA, nAh + nKh, nAw + nKw]
      arrAPadded = rbuild @4 @0 @(TKScalar r) @target shAPadded $ \case
        [iImg, iCinp, iPh, iPw] ->
          ifH (iPh <. fromIntegral (nKh `div` 2)
               ||* iPw <. fromIntegral (nKw `div` 2)
               ||* iPh >=. fromIntegral (nAh + nKh `div` 2)
               ||* iPw >=. fromIntegral (nAw + nKw `div` 2))
              (rscalar 0)
              (arrA ! [ iImg
                      , iCinp
                      , iPh - fromIntegral (nKh `div` 2)
                      , iPw - fromIntegral (nKw `div` 2) ])
      nCinp = assert (nCinpA == nCinpK `blame` (nCinpA, nCinpK)) nCinpA
      shB = [nImgs, nCoutK, nAh, nAw]
      shK1 = [1, nCinp, nKh, nKw]
  in rbuild shB $ \case
    [iImg, iCout, iBh, iBw] ->
      let arrAt = tlet arrAPadded $ \arrAPadded2 ->
                    slicezL shK1 arrAPadded2 [iImg, 0, iBh, iBw]
          arrKt = slicezL shK1 arrK [iCout, 0, 0, 0]
      in rdot0 arrAt arrKt
    _ -> error "conv2dPaddedLet2: impossible pattern needlessly required"

-- TODO: OOMs
_testPaddedCNNOPP2 :: Assertion
_testPaddedCNNOPP2 = do
  resetVarCounter
  let f :: AstTensor AstMethodLet FullSpan
                     (TKProduct (TKR 4 Double) (TKR 4 Double))
        -> AstTensor AstMethodLet FullSpan
                     (TKR 4 Double)
      f v = conv2dPadded2 (tproject1 v) (tproject2 v)
      ftk = FTKProduct (FTKR (2 :$: 2 :$: 2 :$: 2 :$: ZSR) FTKScalar)
                       (FTKR (2 :$: 2 :$: 2 :$: 2 :$: ZSR) FTKScalar)
      (artifactRev, _) =
        revArtifactFromForwardPass
          UseIncomingCotangent (forwardPassByInterpretation f emptyEnv) ftk
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (str (sreplicate @2 (stranspose @[5,0,1,4,2,3] (sgather (sappend (sconcrete (sreplicate [1,4,2,2] 0.0)) (sappend (stranspose @[3,0,2,1] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,4,2,2] 0.0)))) (\\[i24, i25, i27, i28] -> [i25 + i28, i24 + i27])))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (sfromR (tproject1 u1))))))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let w29 = str (sreplicate @2 (sgather (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0)))) (\\[i23, i24, i25, i26, i27, i28] -> [i25 + i28, i23, i26, i24 + i27]))) in rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (w29 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (sfromR (tproject1 u1))))))))))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> tconvert (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2,2,2] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2,2,2] FTKScalar)) ConvSX))) (STKProduct (STKS [2,2,2,2] STKScalar) (STKS [2,2,2,2] STKScalar)) (let w31 = sreshape @[2,2,2,2,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret))) in tpair (ssum @2 (ssum @2 (sdot1In (stranspose @[2,3,0,4,5,6,1] (sreplicate @2 (stranspose @[5,0,1,4,2,3] (sgather (sappend (sconcrete (sreplicate [1,4,2,2] 0.0)) (sappend (stranspose @[3,0,2,1] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,4,2,2] 0.0)))) (\\[i24, i25, i27, i28] -> [i25 + i28, i24 + i27]))))) (stranspose @[2,3,1,4,5,6,0] w31)))) (stranspose @[1,2,0] (sslice (SNat @1) (SNat @2) (stranspose @[3,1,2,0] (sslice (SNat @1) (SNat @2) (sscatter (sdot1In (sreplicate @2 (sreplicate @2 (sreplicate @2 (stranspose @[1,2,3,0] (sfromR (tproject1 u1)))))) (stranspose @[0,2,3,4,5,6,1] w31)) (\\[i32, i33, i34, i35, i36, i37] -> [i34 + i37, i32, i35, i33 + i36])))))))"

conv2dPadded2
  :: forall target r. (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r) -> target (TKR 4 r)
conv2dPadded2 arrK arrA =
  let [nImgs, nCinpA, nAh, nAw] = rshape arrA
      [nCoutK, nCinpK, nKh, nKw] = rshape arrK
      shAPadded = [nImgs, nCinpA, nAh + nKh, nAw + nKw]
      arrAPadded = rbuild @4 @0 @(TKScalar r) @target shAPadded $ \case
        [iImg, iCinp, iPh, iPw] ->
               arrA ! [ iImg
                      , iCinp
                      , iPh - fromIntegral (nKh `div` 2)
                      , iPw - fromIntegral (nKw `div` 2) ]
      nCinp = assert (nCinpA == nCinpK `blame` (nCinpA, nCinpK)) nCinpA
      shB = [nImgs, nCoutK, nAh, nAw]
      shK1 = [1, nCinp, nKh, nKw]
  in rbuild shB $ \case
    [iImg, iCout, iBh, iBw] ->
      let arrAt = slicezL shK1 arrAPadded [iImg, 0, iBh, iBw]
          arrKt = slicezL shK1 arrK [iCout, 0, 0, 0]
      in rdot0 arrAt arrKt
    _ -> error "conv2dPadded2: impossible pattern needlessly required"


-- * Non-laborious CNN PP tests

-- Convolution differentiated wrt the kernel.
testCNNOPP0cW :: Assertion
testCNNOPP0cW = do
  resetVarCounter
  let ftk = FTKR (7 :$: 5 :$: 7 :$: 7 :$: ZSR) (FTKScalar @Double)
      varName = mkAstVarName ftk Nothing . intToAstVarId $ 100000000
      var = AstVar varName
      ftk2 = FTKR (5 :$: 5 :$: 5 :$: 5 :$: ZSR) (FTKScalar @Double)
      f = simplifyInline . flip conv2dUnpadded var
      env =
        extendEnv varName (dDnotShared (AstRaw var) (DeltaZero ftk)) emptyEnv
      (artifactRev, _) =
        revArtifactFromForwardPass
          UseIncomingCotangent (forwardPassByInterpretation f env) ftk2
  "\\u0 -> " ++ printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u0 -> \\u1 -> rfromS (ssum @125 (stranspose @[4,0,1,2,3] (sreshape @[7,5,7,7,125] (str (sreplicate @5 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR u0)) (\\[i81, i83] -> [i81 + i83]))) (\\[i41, i42] -> [i41 + i42])))) * sreplicate @7 (str (sreplicate @7 (str (sreplicate @7 (sfromR u1)))))))))"
  "\\u0 -> " ++ printArtifactPrimalPretty artifactRev
    @?= "\\u0 -> \\u1 -> let w43 = str (sreplicate @5 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR u0)) (\\[i39, i40] -> [i39 + i40]))) (\\[i41, i42] -> [i41 + i42])))) in rfromS (ssum @125 (stranspose @[4,0,1,2,3] (sreshape @[7,5,7,7,125] (w43 * sreplicate @7 (str (sreplicate @7 (str (sreplicate @7 (sfromR u1)))))))))"
  "\\u0 -> " ++ printArtifactPretty artifactRev
    @?= "\\u0 -> \\dret u1 -> let w43 = str (sreplicate @5 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR u0)) (\\[i39, i40] -> [i39 + i40]))) (\\[i41, i42] -> [i41 + i42])))) ; w45 = sreshape @[7,5,7,7,5,5,5] (stranspose @[1,2,3,4,0] (sreplicate @125 (sfromR dret))) in rfromS (ssum @7 (str (ssum @7 (str (ssum @7 (w43 * w45))))))"
  "\\u0 -> " ++ printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\u0 -> \\dret u1 -> rfromS (ssum @7 (ssum @7 (sdot1In (stranspose @[2,3,0,4,5,6,1] (sreplicate @5 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR u0)) (\\[i98, i100] -> [i98 + i100]))) (\\[i41, i42] -> [i41 + i42]))))) (stranspose @[2,3,1,4,5,6,0] (sreshape @[7,5,7,7,5,5,5] (stranspose @[1,2,3,4,0] (sreplicate @125 (sfromR dret))))))))"

-- Convolution differentiated wrt the data.
testCNNOPP0bW :: Assertion
testCNNOPP0bW = do
  resetVarCounter
  let ftk = FTKR (5 :$: 5 :$: 5 :$: 5 :$: ZSR) (FTKScalar @Double)
      varName = mkAstVarName ftk Nothing . intToAstVarId $ 100000000
      var = AstVar varName
      ftk2 = FTKR (7 :$: 5 :$: 7 :$: 7 :$: ZSR) (FTKScalar @Double)
      f = simplifyInline . conv2dUnpadded var
      env =
        extendEnv varName (dDnotShared (AstRaw var) (DeltaZero ftk)) emptyEnv
      (artifactRev, _) =
        revArtifactFromForwardPass
          UseIncomingCotangent (forwardPassByInterpretation f env) ftk2
  "\\u0 -> " ++ printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u0 -> \\u1 -> rfromS (ssum @125 (stranspose @[4,0,1,2,3] (sreshape @[7,5,7,7,125] (str (sreplicate @5 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i64, i66] -> [i64 + i66]))) (\\[i41, i42] -> [i41 + i42])))) * sreplicate @7 (str (sreplicate @7 (str (sreplicate @7 (sfromR u0)))))))))"
  "\\u0 -> " ++ printArtifactPrimalPretty artifactRev
    @?= "\\u0 -> \\u1 -> let w43 = str (sreplicate @5 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i39, i40] -> [i39 + i40]))) (\\[i41, i42] -> [i41 + i42])))) in rfromS (ssum @125 (stranspose @[4,0,1,2,3] (sreshape @[7,5,7,7,125] (w43 * sreplicate @7 (str (sreplicate @7 (str (sreplicate @7 (sfromR u0)))))))))"
  "\\u0 -> " ++ printArtifactPretty artifactRev
    @?= "\\u0 -> \\dret u1 -> let w45 = sreshape @[7,5,7,7,5,5,5] (stranspose @[1,2,3,4,0] (sreplicate @125 (sfromR dret))) in rfromS (stranspose @[1,2,0] (sscatter (stranspose @[2,4,1,3,0] (sscatter (stranspose @[2,5,0,1,3,4] (ssum @5 (str (sreplicate @7 (str (sreplicate @7 (str (sreplicate @7 (sfromR u0))))) * w45)))) (\\[i46, i47] -> [i46 + i47]))) (\\[i48, i49] -> [i48 + i49])))"
  "\\u0 -> " ++ printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\u0 -> \\dret u1 -> rfromS (stranspose @[1,2,0] (sscatter (stranspose @[2,4,1,3,0] (sscatter (sdot1In (stranspose @[3,6,0,2,4,5,1] (sreplicate @7 (str (sreplicate @7 (str (sreplicate @7 (sfromR u0))))))) (stranspose @[3,6,0,2,4,5,1] (sreshape @[7,5,7,7,5,5,5] (stranspose @[1,2,3,4,0] (sreplicate @125 (sfromR dret)))))) (\\[i46, i47] -> [i46 + i47]))) (\\[i48, i49] -> [i48 + i49])))"

testCNNOPP1bW :: Assertion
testCNNOPP1bW = do
  resetVarCounter
  let f :: AstTensor AstMethodLet FullSpan
                     (TKProduct (TKR 4 Double) (TKR 4 Double))
        -> AstTensor AstMethodLet FullSpan
                     (TKR 4 Double)
      f v = simplifyInline $ conv2dUnpadded (tproject1 v) (tproject2 v)
      ftk = FTKProduct (FTKR (7 :$: 7 :$: 7 :$: 7 :$: ZSR) FTKScalar)
                       (FTKR (7 :$: 7 :$: 7 :$: 7 :$: ZSR) FTKScalar)
      (artifactRev, _) =
        revArtifactFromForwardPass
          UseIncomingCotangent (forwardPassByInterpretation f emptyEnv) ftk
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (ssum @343 (stranspose @[4,0,1,2,3] (sreshape @[7,7,7,7,343] (str (sreplicate @7 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR (tproject2 u1))) (\\[i81, i83] -> [i81 + i83]))) (\\[i41, i42] -> [i41 + i42])))) * sreplicate @7 (str (sreplicate @7 (str (sreplicate @7 (sfromR (tproject1 u1))))))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let w43 = str (sreplicate @7 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR (tproject2 u1))) (\\[i39, i40] -> [i39 + i40]))) (\\[i41, i42] -> [i41 + i42])))) in rfromS (ssum @343 (stranspose @[4,0,1,2,3] (sreshape @[7,7,7,7,343] (w43 * sreplicate @7 (str (sreplicate @7 (str (sreplicate @7 (sfromR (tproject1 u1))))))))))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> let w43 = str (sreplicate @7 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR (tproject2 u1))) (\\[i39, i40] -> [i39 + i40]))) (\\[i41, i42] -> [i41 + i42])))) ; w45 = sreshape @[7,7,7,7,7,7,7] (stranspose @[1,2,3,4,0] (sreplicate @343 (sfromR dret))) in tpair (rfromS (ssum @7 (str (ssum @7 (str (ssum @7 (w43 * w45))))))) (rfromS (stranspose @[1,2,0] (sscatter (stranspose @[2,4,1,3,0] (sscatter (stranspose @[2,5,0,1,3,4] (ssum @7 (str (sreplicate @7 (str (sreplicate @7 (str (sreplicate @7 (sfromR (tproject1 u1)))))) * w45)))) (\\[i46, i47] -> [i46 + i47]))) (\\[i48, i49] -> [i48 + i49]))))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> tconvert (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [7,7,7,7] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [7,7,7,7] FTKScalar)) ConvSX))) (STKProduct (STKS [7,7,7,7] STKScalar) (STKS [7,7,7,7] STKScalar)) (let w45 = sreshape @[7,7,7,7,7,7,7] (stranspose @[1,2,3,4,0] (sreplicate @343 (sfromR dret))) in tpair (ssum @7 (ssum @7 (sdot1In (stranspose @[2,3,0,4,5,6,1] (sreplicate @7 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR (tproject2 u1))) (\\[i98, i100] -> [i98 + i100]))) (\\[i41, i42] -> [i41 + i42]))))) (stranspose @[2,3,1,4,5,6,0] w45)))) (stranspose @[1,2,0] (sscatter (stranspose @[2,4,1,3,0] (sscatter (sdot1In (stranspose @[3,6,0,2,4,5,1] (sreplicate @7 (str (sreplicate @7 (str (sreplicate @7 (sfromR (tproject1 u1)))))))) (stranspose @[3,6,0,2,4,5,1] w45)) (\\[i46, i47] -> [i46 + i47]))) (\\[i48, i49] -> [i48 + i49]))))"

testCNNOPP4bW :: Assertion
testCNNOPP4bW = do
  resetVarCounter
  let !artifactRev = revArtifactAdapt UseIncomingCotangent (maxPool2dUnpadded 4 2) (FTKR [7, 7, 7, 7] (FTKScalar @Double))
      !artSimp = simplifyArtifact artifactRev
  let ftk1 = FTKR (7 :$: 7 :$: 7 :$: 7 :$: ZSR) (FTKScalar @Double)
      ftkDt = FTKR (7 :$: 7 :$: 3 :$: 3 :$: ZSR) (FTKScalar @Double)
      env = extendEnv (artVarDtRev artSimp)
                      (tconcrete ftkDt (treplTarget 7 ftkDt))
            $ extendEnv (artVarDomainRev artSimp)
                        (tconcrete ftk1 (treplTarget 42 ftk1)) emptyEnv
  interpretAstPrimal @Concrete env (artPrimalRev artifactRev)
    @?= interpretAstPrimal @Concrete env (artPrimalRev artSimp)
  interpretAstPrimal @Concrete env (artDerivativeRev artifactRev)
    @?= interpretAstPrimal @Concrete env (artDerivativeRev artSimp)
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (let w47 = sreshape @[7,7,3,3,16] (stranspose @[2,3,4,0,5,1] (sgather (stranspose @[4,2,3,0,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i137, i138] -> [2 * i137 + i138]))) (\\[i45, i46] -> [2 * i45 + i46]))) in sgather w47 (\\[i48, i49, i50, i51] -> [i48, i49, i50, i51, kfromS (smaxIndex (w47 !$ [i48, i49, i50, i51]))]))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let w47 = sreshape @[7,7,3,3,16] (stranspose @[2,3,4,0,5,1] (sgather (stranspose @[4,2,3,0,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i43, i44] -> [2 * i43 + i44]))) (\\[i45, i46] -> [2 * i45 + i46]))) in rfromS (sgather w47 (\\[i48, i49, i50, i51] -> [i48, i49, i50, i51, kfromS (smaxIndex (w47 !$ [i48, i49, i50, i51]))]))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> let w47 = sreshape @[7,7,3,3,16] (stranspose @[2,3,4,0,5,1] (sgather (stranspose @[4,2,3,0,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i43, i44] -> [2 * i43 + i44]))) (\\[i45, i46] -> [2 * i45 + i46]))) in rfromS (stranspose @[1,2,0] (sscatter (stranspose @[3,4,1,2,0] (sscatter (stranspose @[3,5,0,1,2,4] (sreshape @[7,7,3,3,4,4] (sscatter (sfromR dret) (\\[i53, i54, i55, i56] -> [i53, i54, i55, i56, kfromS (smaxIndex (w47 !$ [i53, i54, i55, i56]))])))) (\\[i57, i58] -> [2 * i57 + i58]))) (\\[i59, i60] -> [2 * i59 + i60])))"
  -- The remH comes from the indexing of reshape rule and it looks terrible,
  -- but w42 looks even worse, depending on available primitives,
  -- so the rule is probably fine.
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> rfromS (stranspose @[1,2,0] (sscatter (stranspose @[3,4,1,2,0] (sscatter (stranspose @[3,5,0,1,2,4] (sreshape @[7,7,3,3,4,4] (sscatter (sfromR dret) (\\[i53, i54, i55, i56] -> [i53, i54, i55, i56, kfromS (smaxIndex (sgather (sgather (stranspose @[4,2,3,0,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i159, i160] -> [2 * i159 + i160]))) (\\[i45, i46] -> [2 * i45 + i46])) (\\[i151] -> [remH (quotH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i151) 16) 3, remH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i151) 4, remH (quotH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i151) 1008) 7, remH (quotH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i151) 144) 7, remH (quotH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i151) 48) 3, remH (quotH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i151) 4) 4])))])))) (\\[i57, i58] -> [2 * i57 + i58]))) (\\[i59, i60] -> [2 * i59 + i60])))"
  printAstPretty (simplifyInlineContractNoExpand $ artDerivativeRev artifactRev)
    @?= "rfromS (stranspose @[1,2,0] (sscatter (stranspose @[3,4,1,2,0] (sscatter (stranspose @[3,5,0,1,2,4] (sreshape @[7,7,3,3,4,4] (sscatter (sfromR u52) (\\[i53, i54, i55, i56] -> [i53, i54, i55, i56, kfromS (smaxIndex (sreshape @[7,7,3,3,16] (stranspose @[2,3,4,0,5,1] (sgather (stranspose @[4,2,3,0,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i43, i44] -> [2 * i43 + i44]))) (\\[i45, i46] -> [2 * i45 + i46]))) !$ [i53, i54, i55, i56]))])))) (\\[i57, i58] -> [2 * i57 + i58]))) (\\[i59, i60] -> [2 * i59 + i60])))"

testCNNOPP4bD :: Assertion
testCNNOPP4bD = do
  resetVarCounter
  setTotalSharing True
  let !artifactRev = revArtifactAdapt UseIncomingCotangent (maxPool2dUnpadded 4 2) (FTKR [7, 7, 7, 7] (FTKScalar @Double))
      !artSimp = simplifyArtifact artifactRev
  setTotalSharing False
  let ftk1 = FTKR (7 :$: 7 :$: 7 :$: 7 :$: ZSR) (FTKScalar @Double)
      ftkDt = FTKR (7 :$: 7 :$: 3 :$: 3 :$: ZSR) (FTKScalar @Double)
      env = extendEnv (artVarDtRev artSimp)
                      (tconcrete ftkDt (treplTarget 7 ftkDt))
            $ extendEnv (artVarDomainRev artSimp)
                        (tconcrete ftk1 (treplTarget 42 ftk1)) emptyEnv
  interpretAstPrimal @Concrete env (artPrimalRev artifactRev)
    @?= interpretAstPrimal @Concrete env (artPrimalRev artSimp)
  interpretAstPrimal @Concrete env (artDerivativeRev artifactRev)
    @?= interpretAstPrimal @Concrete env (artDerivativeRev artSimp)
  printArtifactPrimalPretty artSimp
    @?= "\\u1 -> rfromS (let w47 = sreshape @[7,7,3,3,16] (stranspose @[2,3,4,0,5,1] (sgather (stranspose @[4,2,3,0,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i98, i99] -> [2 * i98 + i99]))) (\\[i45, i46] -> [2 * i45 + i46]))) in sgather w47 (\\[i48, i49, i50, i51] -> [i48, i49, i50, i51, kfromS (smaxIndex (w47 !$ [i48, i49, i50, i51]))]))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let w47 = sreshape @[7,7,3,3,16] (stranspose @[2,3,4,0,5,1] (sgather (stranspose @[4,2,3,0,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i43, i44] -> [2 * i43 + i44]))) (\\[i45, i46] -> [2 * i45 + i46]))) in rfromS (sgather w47 (\\[i48, i49, i50, i51] -> [i48, i49, i50, i51, kfromS (smaxIndex (w47 !$ [i48, i49, i50, i51]))]))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> let w47 = sreshape @[7,7,3,3,16] (stranspose @[2,3,4,0,5,1] (sgather (stranspose @[4,2,3,0,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i43, i44] -> [2 * i43 + i44]))) (\\[i45, i46] -> [2 * i45 + i46]))) in rfromS (stranspose @[1,2,0] (sscatter (stranspose @[3,4,1,2,0] (sscatter (stranspose @[3,5,0,1,2,4] (sreshape @[7,7,3,3,4,4] (sscatter (sfromR dret) (\\[i53, i54, i55, i56] -> [i53, i54, i55, i56, kfromS (smaxIndex (w47 !$ [i53, i54, i55, i56]))])))) (\\[i57, i58] -> [2 * i57 + i58]))) (\\[i59, i60] -> [2 * i59 + i60])))"
  printArtifactPretty artSimp
    @?= "\\dret u1 -> rfromS (stranspose @[1,2,0] (sscatter (stranspose @[3,4,1,2,0] (sscatter (stranspose @[3,5,0,1,2,4] (sreshape @[7,7,3,3,4,4] (sscatter (sfromR dret) (\\[i53, i54, i55, i56] -> [i53, i54, i55, i56, kfromS (smaxIndex (sgather (sgather (stranspose @[4,2,3,0,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i81, i82] -> [2 * i81 + i82]))) (\\[i45, i46] -> [2 * i45 + i46])) (\\[i73] -> [remH (quotH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i73) 16) 3, remH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i73) 4, remH (quotH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i73) 1008) 7, remH (quotH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i73) 144) 7, remH (quotH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i73) 48) 3, remH (quotH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i73) 4) 4])))])))) (\\[i57, i58] -> [2 * i57 + i58]))) (\\[i59, i60] -> [2 * i59 + i60])))"
  printAstPretty (simplifyInlineContractNoExpand $ artDerivativeRev artifactRev)
    @?= "rfromS (stranspose @[1,2,0] (sscatter (stranspose @[3,4,1,2,0] (sscatter (stranspose @[3,5,0,1,2,4] (sreshape @[7,7,3,3,4,4] (sscatter (sfromR u52) (\\[i53, i54, i55, i56] -> [i53, i54, i55, i56, kfromS (smaxIndex (sreshape @[7,7,3,3,16] (stranspose @[2,3,4,0,5,1] (sgather (stranspose @[4,2,3,0,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i43, i44] -> [2 * i43 + i44]))) (\\[i45, i46] -> [2 * i45 + i46]))) !$ [i53, i54, i55, i56]))])))) (\\[i57, i58] -> [2 * i57 + i58]))) (\\[i59, i60] -> [2 * i59 + i60])))"

testCNNOPP5aW :: Assertion
testCNNOPP5aW = do
  resetVarCounter
  let artifactRev = revArtifactAdapt UseIncomingCotangent (maxPool2dUnpadded 4 2 . conv2dC) (FTKR [7, 2, 7, 7] (FTKScalar @Double))
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (let t49 = sreshape @[2,7,16] (ssum @98 (stranspose @[4,1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (stranspose @[2,0,1] (sreshape @[2,7,2,2,98] (str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i109, i111] -> [i109 + i111]))) (\\[i46, i47] -> [i46 + i47])))))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) (sconcrete (sreplicate [2,2,7,2,98] 0.0)))) (sconcrete (sreplicate [2,2,7,4,98] 0.0))))) in stranspose @[1,2,0] (sreplicate @1 (stranspose @[1,2,0] (sreplicate @1 (sgather t49 (\\[i50, i51] -> [i50, i51, kfromS (smaxIndex (t49 !$ [i50, i51]))]))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let w48 = str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i44, i45] -> [i44 + i45]))) (\\[i46, i47] -> [i46 + i47])))))) ; t49 = sreshape @[2,7,16] (ssum @98 (stranspose @[4,1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (stranspose @[2,0,1] (sreshape @[2,7,2,2,98] (w48 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) (sconcrete (sreplicate [2,2,7,2,98] 0.0)))) (sconcrete (sreplicate [2,2,7,4,98] 0.0))))) in rfromS (stranspose @[1,2,0] (sreplicate @1 (stranspose @[1,2,0] (sreplicate @1 (sgather t49 (\\[i50, i51] -> [i50, i51, kfromS (smaxIndex (t49 !$ [i50, i51]))]))))))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> let w48 = str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i44, i45] -> [i44 + i45]))) (\\[i46, i47] -> [i46 + i47])))))) ; t49 = sreshape @[2,7,16] (ssum @98 (stranspose @[4,1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (stranspose @[2,0,1] (sreshape @[2,7,2,2,98] (w48 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) (sconcrete (sreplicate [2,2,7,2,98] 0.0)))) (sconcrete (sreplicate [2,2,7,4,98] 0.0))))) ; w55 = stranspose @[4,1,2,3,0] (sreplicate @98 (sreshape @[2,7,4,4] (sscatter (ssum @1 (stranspose @[2,0,1] (ssum @1 (stranspose @[2,0,1] (sfromR dret))))) (\\[i53, i54] -> [i53, i54, kfromS (smaxIndex (t49 !$ [i53, i54]))])))) in rfromS (ssum @1 (str (ssum @2 (str (ssum @2 (str (ssum @2 (w48 * sreshape @[2,7,2,2,1,2,7,7] (stranspose @[1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[3,1,2,0] (sslice (SNat @0) (SNat @2) w55))))))))))))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> rfromS (let w48 = sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i155, i157] -> [i155 + i157]))) (\\[i46, i47] -> [i46 + i47]) in ssum @2 (ssum @2 (sdot1In (stranspose @[1,6,0,4,5,3,2] (sreplicate @7 (stranspose @[3,2,1,4,5,0] w48))) (stranspose @[4,2,3,1,5,6,7,0] (sreshape @[2,7,2,2,1,2,7,7] (stranspose @[1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[3,1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[4,1,2,3,0] (sreplicate @98 (sreshape @[2,7,4,4] (sscatter (stranspose @[2,3,0,1] (sfromR dret) !$ [0, 0]) (\\[i53, i54] -> [i53, i54, kfromS (smaxIndex (ssum @98 (str (sgather (sappend (stranspose @[3,1,2,0] (sappend (stranspose @[2,0,1] (sreshape @[2,7,2,2,98] (str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] w48)))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) (sconcrete (sreplicate [2,2,7,2,98] 0.0)))) (sconcrete (sreplicate [2,2,7,4,98] 0.0))) (\\[i129] -> [remH ((112 * i53 + 16 * i54) + i129) 4, remH (quotH ((112 * i53 + 16 * i54) + i129) 112) 2, remH (quotH ((112 * i53 + 16 * i54) + i129) 16) 7, remH (quotH ((112 * i53 + 16 * i54) + i129) 4) 4])))))])))))))))) !$ [0]))))"
  printAstPretty (simplifyInlineContractNoExpand $ artDerivativeRev artifactRev)
    @?= "rfromS (let w48 = sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i44, i45] -> [i44 + i45]))) (\\[i46, i47] -> [i46 + i47]) in ssum @2 (ssum @2 (sdot1In (stranspose @[4,2,3,0,5,6,7,1] (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] w48)))) !$ [0]) (stranspose @[4,2,3,1,5,6,7,0] (sreshape @[2,7,2,2,1,2,7,7] (stranspose @[1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[3,1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[4,1,2,3,0] (sreplicate @98 (sreshape @[2,7,4,4] (sscatter (stranspose @[2,3,0,1] (sfromR u52) !$ [0, 0]) (\\[i53, i54] -> [i53, i54, kfromS (smaxIndex (sreshape @[2,7,16] (ssum @98 (stranspose @[4,1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (stranspose @[2,0,1] (sreshape @[2,7,2,2,98] (str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] w48)))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) (sconcrete (sreplicate [2,2,7,2,98] 0.0)))) (sconcrete (sreplicate [2,2,7,4,98] 0.0))))) !$ [i53, i54]))])))))))))) !$ [0]))))"

testCNNOPP5bW :: Assertion
testCNNOPP5bW = do
  resetVarCounter
  let artifactRev = revArtifactAdapt UseIncomingCotangent (maxPool2dUnpadded 4 2 . relu) (FTKR [7, 2, 7, 7] (FTKScalar @Double))
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (let w75 = sreshape @[7,2,3,3,16] (stranspose @[4,5,0,1,2,3] (sgather (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i61, i62, i63, i64] -> [ifH (sscalar -0.0 <=. negate (sfromR u1 !$ [i63, i64, i61, i62])) 0 1])) (\\[i65, i66, i67, i68] -> [kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i65, i67]), kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i66, i68])])) * stranspose @[4,5,0,1,2,3] (sgather (stranspose @[2,3,0,1] (sfromR u1)) (\\[i69, i70, i71, i72] -> [kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i69, i71]), kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i70, i72])]))) in sgather w75 (\\[i76, i77, i78, i79] -> [i76, i77, i78, i79, kfromS (smaxIndex (w75 !$ [i76, i77, i78, i79]))]))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let m59 = str (sreplicate @4 (sconcrete (sreplicate [3] 2) * siota (SNat @3))) + sreplicate @3 (siota (SNat @4)) ; m60 = str (sreplicate @4 (sconcrete (sreplicate [3] 2) * siota (SNat @3))) + sreplicate @3 (siota (SNat @4)) ; w73 = stranspose @[4,5,0,1,2,3] (sgather (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i61, i62, i63, i64] -> [ifH (sscalar -0.0 <=. negate (sfromR u1 !$ [i63, i64, i61, i62])) 0 1])) (\\[i65, i66, i67, i68] -> [kfromS (m59 !$ [i65, i67]), kfromS (m60 !$ [i66, i68])])) ; w74 = stranspose @[4,5,0,1,2,3] (sgather (stranspose @[2,3,0,1] (sfromR u1)) (\\[i69, i70, i71, i72] -> [kfromS (m59 !$ [i69, i71]), kfromS (m60 !$ [i70, i72])])) ; w75 = sreshape @[7,2,3,3,16] (w73 * w74) in rfromS (sgather w75 (\\[i76, i77, i78, i79] -> [i76, i77, i78, i79, kfromS (smaxIndex (w75 !$ [i76, i77, i78, i79]))]))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> let m59 = str (sreplicate @4 (sconcrete (sreplicate [3] 2) * siota (SNat @3))) + sreplicate @3 (siota (SNat @4)) ; m60 = str (sreplicate @4 (sconcrete (sreplicate [3] 2) * siota (SNat @3))) + sreplicate @3 (siota (SNat @4)) ; w73 = stranspose @[4,5,0,1,2,3] (sgather (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i61, i62, i63, i64] -> [ifH (sscalar -0.0 <=. negate (sfromR u1 !$ [i63, i64, i61, i62])) 0 1])) (\\[i65, i66, i67, i68] -> [kfromS (m59 !$ [i65, i67]), kfromS (m60 !$ [i66, i68])])) ; w74 = stranspose @[4,5,0,1,2,3] (sgather (stranspose @[2,3,0,1] (sfromR u1)) (\\[i69, i70, i71, i72] -> [kfromS (m59 !$ [i69, i71]), kfromS (m60 !$ [i70, i72])])) ; w75 = sreshape @[7,2,3,3,16] (w73 * w74) in rfromS (stranspose @[2,3,0,1] (sscatter (stranspose @[2,3,4,5,0,1] (w73 * sreshape @[7,2,3,3,4,4] (sscatter (sfromR dret) (\\[i81, i82, i83, i84] -> [i81, i82, i83, i84, kfromS (smaxIndex (w75 !$ [i81, i82, i83, i84]))])))) (\\[i85, i86, i87, i88] -> [kfromS (m59 !$ [i85, i87]), kfromS (m60 !$ [i86, i88])])))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> rfromS (let w73 = sgather (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i61, i62, i63, i64] -> [ifH (sscalar -0.0 <=. negate (sfromR u1 !$ [i63, i64, i61, i62])) 0 1])) (\\[i65, i66, i67, i68] -> [kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i65, i67]), kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i66, i68])]) in stranspose @[2,3,0,1] (sscatter (w73 * stranspose @[2,3,4,5,0,1] (sreshape @[7,2,3,3,4,4] (sscatter (sfromR dret) (\\[i81, i82, i83, i84] -> [i81, i82, i83, i84, kfromS (smaxIndex (sgather (stranspose @[4,5,0,1,2,3] w73 * stranspose @[4,5,0,1,2,3] (sgather (stranspose @[2,3,0,1] (sfromR u1)) (\\[i69, i70, i71, i72] -> [kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i69, i71]), kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i70, i72])]))) (\\[i118] -> [remH (quotH ((((288 * i81 + 144 * i82) + 48 * i83) + 16 * i84) + i118) 288) 7, remH (quotH ((((288 * i81 + 144 * i82) + 48 * i83) + 16 * i84) + i118) 144) 2, remH (quotH ((((288 * i81 + 144 * i82) + 48 * i83) + 16 * i84) + i118) 48) 3, remH (quotH ((((288 * i81 + 144 * i82) + 48 * i83) + 16 * i84) + i118) 16) 3, remH (quotH ((((288 * i81 + 144 * i82) + 48 * i83) + 16 * i84) + i118) 4) 4, remH ((((288 * i81 + 144 * i82) + 48 * i83) + 16 * i84) + i118) 4])))])))) (\\[i85, i86, i87, i88] -> [kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i85, i87]), kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i86, i88])])))"
  printAstPretty (simplifyInlineContractNoExpand $ artDerivativeRev artifactRev)
    @?= "rfromS (let w73 = sgather (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i61, i62, i63, i64] -> [ifH (sscalar -0.0 <=. negate (sfromR u1 !$ [i63, i64, i61, i62])) 0 1])) (\\[i65, i66, i67, i68] -> [kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i65, i67]), kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i66, i68])]) in stranspose @[2,3,0,1] (sscatter (w73 * stranspose @[2,3,4,5,0,1] (sreshape @[7,2,3,3,4,4] (sscatter (sfromR u80) (\\[i81, i82, i83, i84] -> [i81, i82, i83, i84, kfromS (smaxIndex (sreshape @[7,2,3,3,16] (stranspose @[4,5,0,1,2,3] w73 * stranspose @[4,5,0,1,2,3] (sgather (stranspose @[2,3,0,1] (sfromR u1)) (\\[i69, i70, i71, i72] -> [kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i69, i71]), kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i70, i72])]))) !$ [i81, i82, i83, i84]))])))) (\\[i85, i86, i87, i88] -> [kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i85, i87]), kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i86, i88])])))"

testCNNOPP5cW :: Assertion
testCNNOPP5cW = do
  resetVarCounter
  let artifactRev = revArtifactAdapt UseIncomingCotangent (relu . conv2dC) (FTKR [7, 2, 7, 7] (FTKScalar @Double))
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (let u45 = ssum @98 (stranspose @[4,0,1,2,3] (sreshape @[2,7,2,2,98] (str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i114, i116] -> [i114 + i116]))) (\\[i42, i43] -> [i42 + i43])))))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) in sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i46, i47, i48, i49] -> [ifH (sscalar -0.0 <=. negate (u45 !$ [i46, i47, i48, i49])) 0 1]) * u45)"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let w44 = str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i40, i41] -> [i40 + i41]))) (\\[i42, i43] -> [i42 + i43])))))) ; u45 = ssum @98 (stranspose @[4,0,1,2,3] (sreshape @[2,7,2,2,98] (w44 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) ; u50 = sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i46, i47, i48, i49] -> [ifH (sscalar -0.0 <=. negate (u45 !$ [i46, i47, i48, i49])) 0 1]) in rfromS (u50 * u45)"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> let w44 = str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i40, i41] -> [i40 + i41]))) (\\[i42, i43] -> [i42 + i43])))))) ; u45 = ssum @98 (stranspose @[4,0,1,2,3] (sreshape @[2,7,2,2,98] (w44 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) ; u50 = sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i46, i47, i48, i49] -> [ifH (sscalar -0.0 <=. negate (u45 !$ [i46, i47, i48, i49])) 0 1]) in rfromS (ssum @1 (str (ssum @2 (str (ssum @2 (str (ssum @2 (w44 * sreshape @[2,7,2,2,1,2,7,7] (stranspose @[1,2,3,4,0] (sreplicate @98 (u50 * sfromR dret)))))))))))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> rfromS (let w44 = sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i169, i171] -> [i169 + i171]))) (\\[i42, i43] -> [i42 + i43]) in ssum @2 (ssum @2 (sdot1In (stranspose @[1,6,0,4,5,3,2] (sreplicate @7 (stranspose @[3,2,1,4,5,0] w44))) (stranspose @[4,2,3,1,5,6,7,0] (sreshape @[2,7,2,2,1,2,7,7] (stranspose @[1,2,3,4,0] (sreplicate @98 (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i46, i47, i48, i49] -> [ifH (sscalar -0.0 <=. negate (ssum0 (sgather (stranspose @[6,1,0,2,4,5,3] (sreplicate @7 (stranspose @[3,2,1,4,5,0] w44)) * sreplicate @2 (sreplicate @2 (str (sreplicate @2 (sfromR u1))))) (\\[i132] -> [remH (quotH ((((2744 * i46 + 392 * i47) + 196 * i48) + 98 * i49) + i132) 98) 2, remH (quotH ((((2744 * i46 + 392 * i47) + 196 * i48) + 98 * i49) + i132) 196) 2, remH (quotH ((((2744 * i46 + 392 * i47) + 196 * i48) + 98 * i49) + i132) 392) 7, remH (quotH ((((2744 * i46 + 392 * i47) + 196 * i48) + 98 * i49) + i132) 2744) 2, remH (quotH ((((2744 * i46 + 392 * i47) + 196 * i48) + 98 * i49) + i132) 49) 2, remH (quotH ((((2744 * i46 + 392 * i47) + 196 * i48) + 98 * i49) + i132) 7) 7, remH ((((2744 * i46 + 392 * i47) + 196 * i48) + 98 * i49) + i132) 7])))) 0 1]) * sfromR dret)))) !$ [0]))))"
  printAstPretty (simplifyInlineContractNoExpand $ artDerivativeRev artifactRev)
    @?= "rfromS (let w44 = sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i40, i41] -> [i40 + i41]))) (\\[i42, i43] -> [i42 + i43]) in ssum @2 (ssum @2 (sdot1In (stranspose @[4,2,3,0,5,6,7,1] (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] w44)))) !$ [0]) (stranspose @[4,2,3,1,5,6,7,0] (sreshape @[2,7,2,2,1,2,7,7] (stranspose @[1,2,3,4,0] (sreplicate @98 (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i46, i47, i48, i49] -> [ifH (sscalar -0.0 <=. negate (ssum0 (sreshape @[2,7,2,2,98] (str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] w44)))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))) !$ [i46, i47, i48, i49]))) 0 1]) * sfromR u51)))) !$ [0]))))"

testCNNOPP5dW :: Assertion
testCNNOPP5dW = do
  resetVarCounter
  let artifactRev = revArtifactAdapt UseIncomingCotangent (maxPool2dUnpadded 4 2 . relu . conv2dC) (FTKR [7, 2, 7, 7] (FTKScalar @Double))
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (let u67 = ssum @98 (stranspose @[4,0,1,2,3] (sreshape @[2,7,2,2,98] (str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i130, i132] -> [i130 + i132]))) (\\[i64, i65] -> [i64 + i65])))))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) ; t74 = sreshape @[2,7,16] (stranspose @[1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i68, i69, i70, i71] -> [ifH (sscalar -0.0 <=. negate (u67 !$ [i69, i70, i68, i71])) 0 1])) (sconcrete (sreplicate [2,2,7,2] 0.0)))) (sconcrete (sreplicate [2,2,7,4] 0.0))) * stranspose @[1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (stranspose @[2,0,1] u67) (sconcrete (sreplicate [2,2,7,2] 0.0)))) (sconcrete (sreplicate [2,2,7,4] 0.0)))) in stranspose @[1,2,0] (sreplicate @1 (stranspose @[1,2,0] (sreplicate @1 (sgather t74 (\\[i75, i76] -> [i75, i76, kfromS (smaxIndex (t74 !$ [i75, i76]))]))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let w66 = str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i62, i63] -> [i62 + i63]))) (\\[i64, i65] -> [i64 + i65])))))) ; u67 = ssum @98 (stranspose @[4,0,1,2,3] (sreshape @[2,7,2,2,98] (w66 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) ; u72 = stranspose @[1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i68, i69, i70, i71] -> [ifH (sscalar -0.0 <=. negate (u67 !$ [i69, i70, i68, i71])) 0 1])) (sconcrete (sreplicate [2,2,7,2] 0.0)))) (sconcrete (sreplicate [2,2,7,4] 0.0))) ; u73 = stranspose @[1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (stranspose @[2,0,1] u67) (sconcrete (sreplicate [2,2,7,2] 0.0)))) (sconcrete (sreplicate [2,2,7,4] 0.0))) ; t74 = sreshape @[2,7,16] (u72 * u73) in rfromS (stranspose @[1,2,0] (sreplicate @1 (stranspose @[1,2,0] (sreplicate @1 (sgather t74 (\\[i75, i76] -> [i75, i76, kfromS (smaxIndex (t74 !$ [i75, i76]))]))))))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> let w66 = str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i62, i63] -> [i62 + i63]))) (\\[i64, i65] -> [i64 + i65])))))) ; u67 = ssum @98 (stranspose @[4,0,1,2,3] (sreshape @[2,7,2,2,98] (w66 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) ; u72 = stranspose @[1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i68, i69, i70, i71] -> [ifH (sscalar -0.0 <=. negate (u67 !$ [i69, i70, i68, i71])) 0 1])) (sconcrete (sreplicate [2,2,7,2] 0.0)))) (sconcrete (sreplicate [2,2,7,4] 0.0))) ; u73 = stranspose @[1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (stranspose @[2,0,1] u67) (sconcrete (sreplicate [2,2,7,2] 0.0)))) (sconcrete (sreplicate [2,2,7,4] 0.0))) ; t74 = sreshape @[2,7,16] (u72 * u73) ; u80 = stranspose @[3,0,1,2] (u72 * sreshape @[2,7,4,4] (sscatter (ssum @1 (stranspose @[2,0,1] (ssum @1 (stranspose @[2,0,1] (sfromR dret))))) (\\[i78, i79] -> [i78, i79, kfromS (smaxIndex (t74 !$ [i78, i79]))]))) in rfromS (ssum @1 (str (ssum @2 (str (ssum @2 (str (ssum @2 (w66 * sreshape @[2,7,2,2,1,2,7,7] (stranspose @[1,2,3,4,0] (sreplicate @98 (stranspose @[1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[3,1,2,0] (sslice (SNat @0) (SNat @2) u80))))))))))))))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> rfromS (let w66 = sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i172, i174] -> [i172 + i174]))) (\\[i64, i65] -> [i64 + i65]) ; u67 = ssum @98 (stranspose @[4,0,1,2,3] (sreshape @[2,7,2,2,98] (str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] w66)))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) ; u72 = sappend (stranspose @[3,1,2,0] (sappend (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i68, i69, i70, i71] -> [ifH (sscalar -0.0 <=. negate (u67 !$ [i69, i70, i68, i71])) 0 1])) (sconcrete (sreplicate [2,2,7,2] 0.0)))) (sconcrete (sreplicate [2,2,7,4] 0.0)) in ssum @2 (ssum @2 (sdot1In (stranspose @[1,6,0,4,5,3,2] (sreplicate @7 (stranspose @[3,2,1,4,5,0] w66))) (stranspose @[4,2,3,1,5,6,7,0] (sreshape @[2,7,2,2,1,2,7,7] (stranspose @[1,2,3,4,0] (sreplicate @98 (stranspose @[1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[3,1,2,0] (sslice (SNat @0) (SNat @2) u72))) * stranspose @[1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[3,1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[3,0,1,2] (sreshape @[2,7,4,4] (sscatter (stranspose @[2,3,0,1] (sfromR dret) !$ [0, 0]) (\\[i78, i79] -> [i78, i79, kfromS (smaxIndex (sgather (stranspose @[1,2,3,0] u72 * stranspose @[1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (stranspose @[2,0,1] u67) (sconcrete (sreplicate [2,2,7,2] 0.0)))) (sconcrete (sreplicate [2,2,7,4] 0.0)))) (\\[i146] -> [remH (quotH ((112 * i78 + 16 * i79) + i146) 112) 2, remH (quotH ((112 * i78 + 16 * i79) + i146) 16) 7, remH (quotH ((112 * i78 + 16 * i79) + i146) 4) 4, remH ((112 * i78 + 16 * i79) + i146) 4])))]))))))))))) !$ [0]))))"
  printAstPretty (simplifyInlineContractNoExpand $ artDerivativeRev artifactRev)
    @?= "rfromS (let w66 = sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i62, i63] -> [i62 + i63]))) (\\[i64, i65] -> [i64 + i65]) ; u67 = ssum @98 (stranspose @[4,0,1,2,3] (sreshape @[2,7,2,2,98] (str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] w66)))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) ; u72 = sappend (stranspose @[3,1,2,0] (sappend (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i68, i69, i70, i71] -> [ifH (sscalar -0.0 <=. negate (u67 !$ [i69, i70, i68, i71])) 0 1])) (sconcrete (sreplicate [2,2,7,2] 0.0)))) (sconcrete (sreplicate [2,2,7,4] 0.0)) in ssum @2 (ssum @2 (sdot1In (stranspose @[4,2,3,0,5,6,7,1] (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] w66)))) !$ [0]) (stranspose @[4,2,3,1,5,6,7,0] (sreshape @[2,7,2,2,1,2,7,7] (stranspose @[1,2,3,4,0] (sreplicate @98 (stranspose @[1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[3,1,2,0] (sslice (SNat @0) (SNat @2) u72))) * stranspose @[1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[3,1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[3,0,1,2] (sreshape @[2,7,4,4] (sscatter (stranspose @[2,3,0,1] (sfromR u77) !$ [0, 0]) (\\[i78, i79] -> [i78, i79, kfromS (smaxIndex (sreshape @[2,7,16] (stranspose @[1,2,3,0] u72 * stranspose @[1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (stranspose @[2,0,1] u67) (sconcrete (sreplicate [2,2,7,2] 0.0)))) (sconcrete (sreplicate [2,2,7,4] 0.0)))) !$ [i78, i79]))]))))))))))) !$ [0]))))"
