{-# LANGUAGE OverloadedLists #-}
-- | Tests of "MnistRnnRanked2" recurrent neural networks using a few different
-- optimization pipelines.
--
-- Not LSTM.
-- Doesn't train without Adam, regardless of whether mini-batches are used.
-- It does train with Adam, but only after very carefully tweaking
-- initialization. This is extremely sensitive to initial parameters,
-- more than to anything else. Probably the gradient is vanishing
-- if parameters are initialized with a probability distribution that
-- doesn't have the right variance. See
-- https://stats.stackexchange.com/questions/301285/what-is-vanishing-gradient.
-- Regularization/normalization might help as well.
module TestMnistRNNR
  ( testTrees
  ) where

import Prelude

import Control.Monad (foldM, unless)
import System.IO (hPutStrLn, stderr)
import System.Random
import Test.Tasty
import Test.Tasty.HUnit hiding (assert)
import Text.Printf

import Data.Array.Nested.Ranked.Shape

import HordeAd
import HordeAd.Core.Adaptor
import HordeAd.Core.AstEnv
import HordeAd.Core.AstFreshId
import HordeAd.Core.AstInterpret

import EqEpsilon

import MnistData
import MnistRnnRanked2 (ADRnnMnistParameters, ADRnnMnistParametersShaped)
import MnistRnnRanked2 qualified

-- TODO: optimize enough that it can run for one full epoch in reasonable time
-- and then verify it trains down to ~20% validation error in a short enough
-- time to include such a training run in tests.
testTrees :: [TestTree]
testTrees = [ tensorADValMnistTestsRNNRA
            , tensorADValMnistTestsRNNRI
            , tensorADValMnistTestsRNNRO
            ]
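
-- The module header notes that training only converges when the initial
-- parameters have the right variance (see the vanishing-gradient link
-- above). Below is a minimal, self-contained sketch of the usual remedy:
-- scale the spread of each weight by 1 / sqrt(fanIn), so pre-activations
-- keep roughly unit variance as the recurrence unrolls over time steps.
-- The helper name `scaledUniformInit` is hypothetical and not part of this
-- test suite or of horde-ad; the `randomValue ... 0.23` calls below play
-- the analogous role for the real parameter records.
scaledUniformInit :: Int -> Int -> StdGen -> ([Double], StdGen)
scaledUniformInit fanIn n g0 =
  let bound = 1 / sqrt (fromIntegral fanIn)  -- spread shrinks with fan-in
      go 0 g = ([], g)
      go k g = let (x, g') = randomR (negate bound, bound) g
                   (xs, g'') = go (k - 1 :: Int) g'
               in (x : xs, g'')
  in go n g0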

-- POPL differentiation, straight via the ADVal instance of RankedTensor,
-- which side-steps vectorization.
mnistTestCaseRNNRA
  :: forall r.
     (Differentiable r, GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r)
  => String -> Int -> Int -> Int -> Int -> Int -> r
  -> TestTree
mnistTestCaseRNNRA prefix epochs maxBatches width miniBatchSize totalBatchSize
                   expected =
 withSNat width $ \(SNat @width) ->
  let targetInit =
        forgetShape $ fst
        $ randomValue @(Concrete (X (ADRnnMnistParametersShaped
                                       Concrete width r)))
                      0.23 (mkStdGen 44)
      name = prefix ++ ": "
             ++ unwords [ show epochs, show maxBatches
                        , show width, show miniBatchSize
                        , show $ widthSTK
                          $ knownSTK @(X (ADRnnMnistParameters Concrete r))
                        , show (tsize knownSTK targetInit) ]
      ftest :: Int -> MnistDataBatchR r
            -> Concrete (X (ADRnnMnistParameters Concrete r)) -> r
      ftest batch_size mnistData pars =
        MnistRnnRanked2.rnnMnistTestR batch_size mnistData
                                      (fromTarget @Concrete pars)
  in testCase name $ do
       hPutStrLn stderr $
         printf "\n%s: Epochs to run/max batches per epoch: %d/%d"
                prefix epochs maxBatches
       trainData <- map mkMnistDataR
                    <$> loadMnistData trainGlyphsPath trainLabelsPath
       testData <- map mkMnistDataR . take (totalBatchSize * maxBatches)
                   <$> loadMnistData testGlyphsPath testLabelsPath
       let testDataR = mkMnistDataBatchR testData
           f :: MnistDataBatchR r
             -> ADVal Concrete (X (ADRnnMnistParameters Concrete r))
             -> ADVal Concrete (TKScalar r)
           f (glyphR, labelR) adinputs =
             MnistRnnRanked2.rnnMnistLossFusedR
               miniBatchSize (rconcrete glyphR, rconcrete labelR)
               (fromTarget @(ADVal Concrete) adinputs)
           runBatch :: ( Concrete (X (ADRnnMnistParameters Concrete r))
                       , StateAdam (X (ADRnnMnistParameters Concrete r)) )
                    -> (Int, [MnistDataR r])
                    -> IO ( Concrete (X (ADRnnMnistParameters Concrete r))
                          , StateAdam (X (ADRnnMnistParameters Concrete r)) )
           runBatch (!parameters, !stateAdam) (k, chunk) = do
             let chunkR = map mkMnistDataBatchR
                          $ filter (\ch -> length ch == miniBatchSize)
                          $ chunksOf miniBatchSize chunk
                 res@(parameters2, _) =
                   sgdAdam @(MnistDataBatchR r)
                           @(X (ADRnnMnistParameters Concrete r))
                           f chunkR parameters stateAdam
                 trainScore =
                   ftest (length chunk) (mkMnistDataBatchR chunk) parameters2
                 testScore =
                   ftest ((totalBatchSize * maxBatches) `min` 10000)
                         testDataR parameters2
                 lenChunk = length chunk
             unless (width < 10) $ do
               hPutStrLn stderr $
                 printf "\n%s: (Batch %d with %d points)" prefix k lenChunk
               hPutStrLn stderr $
                 printf "%s: Training error: %.2f%%"
                        prefix ((1 - trainScore) * 100)
               hPutStrLn stderr $
                 printf "%s: Validation error: %.2f%%"
                        prefix ((1 - testScore) * 100)
             return res
       let runEpoch :: Int
                    -> ( Concrete (X (ADRnnMnistParameters Concrete r))
                       , StateAdam (X (ADRnnMnistParameters Concrete r)) )
                    -> IO (Concrete (X (ADRnnMnistParameters Concrete r)))
           runEpoch n (params2, _) | n > epochs = return params2
           runEpoch n paramsStateAdam@(!_, !_) = do
             unless (width < 10) $
               hPutStrLn stderr $ printf "\n%s: [Epoch %d]" prefix n
             let trainDataShuffled = shuffle (mkStdGen $ n + 5) trainData
                 chunks = take maxBatches
                          $ zip [1 ..]
                          $ chunksOf totalBatchSize trainDataShuffled
             res <- foldM runBatch paramsStateAdam chunks
             runEpoch (succ n) res
           ftk = tftk @Concrete
                      (knownSTK @(X (ADRnnMnistParameters Concrete r)))
                      targetInit
       res <- runEpoch 1 (targetInit, initialStateAdam ftk)
       let testErrorFinal =
             1 - ftest ((totalBatchSize * maxBatches) `min` 10000)
                       testDataR res
       testErrorFinal @?~ expected

{-# SPECIALIZE mnistTestCaseRNNRA
  :: String -> Int -> Int -> Int -> Int -> Int -> Double -> TestTree #-}

tensorADValMnistTestsRNNRA :: TestTree
tensorADValMnistTestsRNNRA = testGroup "RNNR ADVal MNIST tests"
  [ mnistTestCaseRNNRA "RNNRA 1 epoch, 1 batch" 1 1 128 150 5000
                       (0.6026 :: Double)
  , mnistTestCaseRNNRA "RNNRA artificial 1 2 3 4 5" 2 3 4 5 50
                       (0.8933333 :: Float)
  , mnistTestCaseRNNRA "RNNRA artificial 5 4 3 2 1" 5 4 3 2 49
                       (0.8622448979591837 :: Double)
  , mnistTestCaseRNNRA "RNNRA 1 epoch, 0 batch" 1 0 128 150 50
                       (1.0 :: Float)
  ]
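
-- As the module header says, these pipelines only converge with Adam,
-- invoked above via horde-ad's sgdAdam (and, in the RNNRO variant below,
-- via updateWithGradientAdam). For reference, here is a minimal sketch of
-- the standard Adam update rule for a single scalar parameter. The names
-- `AdamState` and `adamStep` and the hyperparameter values are illustrative
-- only, unrelated to horde-ad's actual StateAdam internals.
data AdamState = AdamState
  { adamT :: Int     -- step counter, for bias correction
  , adamM :: Double  -- first-moment (mean) estimate
  , adamV :: Double  -- second-moment (uncentered variance) estimate
  }

adamStep :: Double -> AdamState -> Double -> (Double, AdamState)
adamStep theta (AdamState t m v) g =
  let (alpha, beta1, beta2, eps) = (0.001, 0.9, 0.999, 1e-7)
      t' = t + 1
      m' = beta1 * m + (1 - beta1) * g      -- momentum on the gradient
      v' = beta2 * v + (1 - beta2) * g * g  -- running mean of g^2
      mHat = m' / (1 - beta1 ^ t')          -- bias-corrected moments
      vHat = v' / (1 - beta2 ^ t')
  in (theta - alpha * mHat / (sqrt vHat + eps), AdamState t' m' v')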

-- POPL differentiation, with Ast term defined and vectorized only once,
-- but differentiated anew in each gradient descent iteration.
mnistTestCaseRNNRI
  :: forall r.
     (Differentiable r, GoodScalar r, PrintfArg r, AssertEqualUpToEpsilon r)
  => String -> Int -> Int -> Int -> Int -> Int -> r
  -> TestTree
mnistTestCaseRNNRI prefix epochs maxBatches width miniBatchSize totalBatchSize
                   expected =
 withSNat width $ \(SNat @width) ->
  let targetInit =
        forgetShape $ fst
        $ randomValue @(Concrete (X (ADRnnMnistParametersShaped
                                       Concrete width r)))
                      0.23 (mkStdGen 44)
      name = prefix ++ ": "
             ++ unwords [ show epochs, show maxBatches
                        , show width, show miniBatchSize
                        , show $ widthSTK
                          $ knownSTK @(X (ADRnnMnistParameters Concrete r))
                        , show (tsize knownSTK targetInit) ]
      ftest :: Int -> MnistDataBatchR r
            -> Concrete (X (ADRnnMnistParameters Concrete r)) -> r
      ftest batch_size mnistData pars =
        MnistRnnRanked2.rnnMnistTestR batch_size mnistData
                                      (fromTarget @Concrete pars)
  in testCase name $ do
       hPutStrLn stderr $
         printf "\n%s: Epochs to run/max batches per epoch: %d/%d"
                prefix epochs maxBatches
       trainData <- map mkMnistDataR
                    <$> loadMnistData trainGlyphsPath trainLabelsPath
       testData <- map mkMnistDataR . take (totalBatchSize * maxBatches)
                   <$> loadMnistData testGlyphsPath testLabelsPath
       let testDataR = mkMnistDataBatchR testData
           ftk = tftk @Concrete
                      (knownSTK @(X (ADRnnMnistParameters Concrete r)))
                      targetInit
       (_, _, var, varAst) <- funToAstRevIO ftk
       (varGlyph, astGlyph) <-
         funToAstIO (FTKR (miniBatchSize
                           :$: sizeMnistHeightInt
                           :$: sizeMnistWidthInt
                           :$: ZSR) FTKScalar) id
       (varLabel, astLabel) <-
         funToAstIO (FTKR (miniBatchSize :$: sizeMnistLabelInt :$: ZSR)
                          FTKScalar) id
       let ast :: AstTensor AstMethodLet FullSpan (TKScalar r)
           ast = simplifyInline
                 $ MnistRnnRanked2.rnnMnistLossFusedR
                     miniBatchSize (astGlyph, astLabel) (fromTarget varAst)
           f :: MnistDataBatchR r
             -> ADVal Concrete (X (ADRnnMnistParameters Concrete r))
             -> ADVal Concrete (TKScalar r)
           f (glyph, label) varInputs =
             let env = extendEnv var varInputs emptyEnv
                 envMnist = extendEnv varGlyph (rconcrete glyph)
                            $ extendEnv varLabel (rconcrete label) env
             in interpretAstFull envMnist ast
           runBatch :: ( Concrete (X (ADRnnMnistParameters Concrete r))
                       , StateAdam (X (ADRnnMnistParameters Concrete r)) )
                    -> (Int, [MnistDataR r])
                    -> IO ( Concrete (X (ADRnnMnistParameters Concrete r))
                          , StateAdam (X (ADRnnMnistParameters Concrete r)) )
           runBatch (!parameters, !stateAdam) (k, chunk) = do
             let chunkR = map mkMnistDataBatchR
                          $ filter (\ch -> length ch == miniBatchSize)
                          $ chunksOf miniBatchSize chunk
                 res@(parameters2, _) =
                   sgdAdam @(MnistDataBatchR r)
                           @(X (ADRnnMnistParameters Concrete r))
                           f chunkR parameters stateAdam
                 trainScore =
                   ftest (length chunk) (mkMnistDataBatchR chunk) parameters2
                 testScore =
                   ftest ((totalBatchSize * maxBatches) `min` 10000)
                         testDataR parameters2
                 lenChunk = length chunk
             unless (width < 10) $ do
               hPutStrLn stderr $
                 printf "\n%s: (Batch %d with %d points)" prefix k lenChunk
               hPutStrLn stderr $
                 printf "%s: Training error: %.2f%%"
                        prefix ((1 - trainScore) * 100)
               hPutStrLn stderr $
                 printf "%s: Validation error: %.2f%%"
                        prefix ((1 - testScore) * 100)
             return res
       let runEpoch :: Int
                    -> ( Concrete (X (ADRnnMnistParameters Concrete r))
                       , StateAdam (X (ADRnnMnistParameters Concrete r)) )
                    -> IO (Concrete (X (ADRnnMnistParameters Concrete r)))
           runEpoch n (params2, _) | n > epochs = return params2
           runEpoch n paramsStateAdam@(!_, !_) = do
             unless (width < 10) $
               hPutStrLn stderr $ printf "\n%s: [Epoch %d]" prefix n
             let trainDataShuffled = shuffle (mkStdGen $ n + 5) trainData
                 chunks = take maxBatches
                          $ zip [1 ..]
                          $ chunksOf totalBatchSize trainDataShuffled
             res <- foldM runBatch paramsStateAdam chunks
             runEpoch (succ n) res
       res <- runEpoch 1 (targetInit, initialStateAdam ftk)
       let testErrorFinal =
             1 - ftest ((totalBatchSize * maxBatches) `min` 10000)
                       testDataR res
       testErrorFinal @?~ expected

{-# SPECIALIZE mnistTestCaseRNNRI
  :: String -> Int -> Int -> Int -> Int -> Int -> Double -> TestTree #-}

tensorADValMnistTestsRNNRI :: TestTree
tensorADValMnistTestsRNNRI = testGroup "RNNR Intermediate MNIST tests"
  [ mnistTestCaseRNNRI "RNNRI 1 epoch, 1 batch" 1 1 128 150 5000
                       (0.6026 :: Double)
  , mnistTestCaseRNNRI "RNNRI artificial 1 2 3 4 5" 2 3 4 5 50
                       (0.8933333 :: Float)
  , mnistTestCaseRNNRI "RNNRI artificial 5 4 3 2 1" 5 4 3 2 49
                       (0.8622448979591837 :: Double)
  , mnistTestCaseRNNRI "RNNRI 1 epoch, 0 batch" 1 0 128 150 50
                       (1.0 :: Float)
  ]
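
-- The "Intermediate" pipeline above stages the loss: the Ast term is built,
-- vectorized and simplified once, outside the training loop, and only
-- interpreted against fresh inputs inside the loop (it is still
-- differentiated anew each iteration). A toy, self-contained model of just
-- the build-once/interpret-many-times aspect, with a hypothetical `Expr`
-- type unrelated to horde-ad's Ast:
data Expr = Var String | Lit Double | Add Expr Expr | Mul Expr Expr

evalExpr :: [(String, Double)] -> Expr -> Double
evalExpr env e = case e of
  Var s   -> maybe (error $ "unbound: " ++ s) id (lookup s env)
  Lit d   -> d
  Add a b -> evalExpr env a + evalExpr env b
  Mul a b -> evalExpr env a * evalExpr env b

-- Build the term once...
sharedTerm :: Expr
sharedTerm = Add (Mul (Var "w") (Var "x")) (Var "b")

-- ...then interpret it with a different environment per iteration,
-- mirroring how `ast` above is reused via extendEnv/interpretAstFull.
interpretTwice :: (Double, Double)
interpretTwice =
  ( evalExpr [("w", 2), ("x", 3), ("b", 1)] sharedTerm
  , evalExpr [("w", 2), ("x", 5), ("b", 1)] sharedTerm )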

-- JAX differentiation, Ast term built and differentiated only once
-- and the result interpreted with different inputs in each gradient
-- descent iteration.
mnistTestCaseRNNRO
  :: forall r.
     ( Differentiable r, GoodScalar r
     , PrintfArg r, AssertEqualUpToEpsilon r, ADTensorScalar r ~ r )
  => String -> Int -> Int -> Int -> Int -> Int -> r
  -> TestTree
mnistTestCaseRNNRO prefix epochs maxBatches width miniBatchSize totalBatchSize
                   expected =
 withSNat width $ \(SNat @width) ->
  let targetInit =
        forgetShape $ fst
        $ randomValue @(Concrete (X (ADRnnMnistParametersShaped
                                       Concrete width r)))
                      0.23 (mkStdGen 44)
      name = prefix ++ ": "
             ++ unwords [ show epochs, show maxBatches
                        , show width, show miniBatchSize
                        , show $ widthSTK
                          $ knownSTK @(X (ADRnnMnistParameters Concrete r))
                        , show (tsize knownSTK targetInit) ]
      ftest :: Int -> MnistDataBatchR r
            -> Concrete (X (ADRnnMnistParameters Concrete r)) -> r
      ftest batch_size mnistData pars =
        MnistRnnRanked2.rnnMnistTestR batch_size mnistData
                                      (fromTarget @Concrete pars)
  in testCase name $ do
       hPutStrLn stderr $
         printf "\n%s: Epochs to run/max batches per epoch: %d/%d"
                prefix epochs maxBatches
       trainData <- map mkMnistDataR
                    <$> loadMnistData trainGlyphsPath trainLabelsPath
       testData <- map mkMnistDataR . take (totalBatchSize * maxBatches)
                   <$> loadMnistData testGlyphsPath testLabelsPath
       let testDataR = mkMnistDataBatchR testData
           dataInit = case chunksOf miniBatchSize testData of
             d : _ -> let (dglyph, dlabel) = mkMnistDataBatchR d
                      in (rconcrete dglyph, rconcrete dlabel)
             [] -> error "empty test data"
           f :: ( ADRnnMnistParameters (AstTensor AstMethodLet FullSpan) r
                , ( AstTensor AstMethodLet FullSpan (TKR 3 r)
                  , AstTensor AstMethodLet FullSpan (TKR 2 r) ) )
             -> AstTensor AstMethodLet FullSpan (TKScalar r)
           f = \ (pars, (glyphR, labelR)) ->
             MnistRnnRanked2.rnnMnistLossFusedR
               miniBatchSize (rprimalPart glyphR, rprimalPart labelR) pars
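           -- Note: the loss is differentiated symbolically exactly once,
           -- in `artRaw` below, and the simplified artifact `art` is then
           -- merely re-interpreted on fresh inputs in each iteration of
           -- `go`; this is what "built and differentiated only once" in
           -- the comment above refers to.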
           artRaw = gradArtifact f (fromTarget targetInit, dataInit)
           art = simplifyArtifactGradient artRaw
           go :: [MnistDataBatchR r]
              -> ( Concrete (X (ADRnnMnistParameters Concrete r))
                 , StateAdam (X (ADRnnMnistParameters Concrete r)) )
              -> ( Concrete (X (ADRnnMnistParameters Concrete r))
                 , StateAdam (X (ADRnnMnistParameters Concrete r)) )
           go [] (parameters, stateAdam) = (parameters, stateAdam)
           go ((glyph, label) : rest) (!parameters, !stateAdam) =
             let parametersAndInput =
                   tpair parameters (tpair (rconcrete glyph)
                                           (rconcrete label))
                 gradient =
                   tproject1 $ fst
                   $ revInterpretArtifact art parametersAndInput Nothing
             in go rest (updateWithGradientAdam
                           @(X (ADRnnMnistParameters Concrete r))
                           defaultArgsAdam stateAdam knownSTK parameters
                           gradient)
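           -- runBatch and runEpoch below are the same driver loop as in
           -- the RNNRA and RNNRI pipelines, except that the step function
           -- is `go` (replaying the precomputed artifact) instead of
           -- sgdAdam.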
           runBatch :: ( Concrete (X (ADRnnMnistParameters Concrete r))
                       , StateAdam (X (ADRnnMnistParameters Concrete r)) )
                    -> (Int, [MnistDataR r])
                    -> IO ( Concrete (X (ADRnnMnistParameters Concrete r))
                          , StateAdam (X (ADRnnMnistParameters Concrete r)) )
           runBatch (!parameters, !stateAdam) (k, chunk) = do
             let chunkR = map mkMnistDataBatchR
                          $ filter (\ch -> length ch == miniBatchSize)
                          $ chunksOf miniBatchSize chunk
                 res@(parameters2, _) = go chunkR (parameters, stateAdam)
                 trainScore =
                   ftest (length chunk) (mkMnistDataBatchR chunk) parameters2
                 testScore =
                   ftest ((totalBatchSize * maxBatches) `min` 10000)
                         testDataR parameters2
                 lenChunk = length chunk
             unless (width < 10) $ do
               hPutStrLn stderr $
                 printf "\n%s: (Batch %d with %d points)" prefix k lenChunk
               hPutStrLn stderr $
                 printf "%s: Training error: %.2f%%"
                        prefix ((1 - trainScore) * 100)
               hPutStrLn stderr $
                 printf "%s: Validation error: %.2f%%"
                        prefix ((1 - testScore) * 100)
             return res
       let runEpoch :: Int
                    -> ( Concrete (X (ADRnnMnistParameters Concrete r))
                       , StateAdam (X (ADRnnMnistParameters Concrete r)) )
                    -> IO (Concrete (X (ADRnnMnistParameters Concrete r)))
           runEpoch n (params2, _) | n > epochs = return params2
           runEpoch n paramsStateAdam@(!_, !_) = do
             unless (width < 10) $
               hPutStrLn stderr $ printf "\n%s: [Epoch %d]" prefix n
             let trainDataShuffled = shuffle (mkStdGen $ n + 5) trainData
                 chunks = take maxBatches
                          $ zip [1 ..]
                          $ chunksOf totalBatchSize trainDataShuffled
             res <- foldM runBatch paramsStateAdam chunks
             runEpoch (succ n) res
           ftk = tftk @Concrete
                      (knownSTK @(X (ADRnnMnistParameters Concrete r)))
                      targetInit
       res <- runEpoch 1 (targetInit, initialStateAdam ftk)
       let testErrorFinal =
             1 - ftest ((totalBatchSize * maxBatches) `min` 10000)
                       testDataR res
       assertEqualUpToEpsilon 1e-1 expected testErrorFinal

{-# SPECIALIZE mnistTestCaseRNNRO
  :: String -> Int -> Int -> Int -> Int -> Int -> Double -> TestTree #-}

tensorADValMnistTestsRNNRO :: TestTree
tensorADValMnistTestsRNNRO = testGroup "RNNR Once MNIST tests"
  [ mnistTestCaseRNNRO "RNNRO 1 epoch, 1 batch" 1 1 128 150 5000
                       (0.6026 :: Double)
  , mnistTestCaseRNNRO "RNNRO artificial 1 2 3 4 5" 2 3 4 5 50
                       (0.8933333 :: Float)
  , mnistTestCaseRNNRO "RNNRO artificial 5 4 3 2 1" 5 4 3 2 49
                       (0.8928571428571429 :: Double)
  , mnistTestCaseRNNRO "RNNRO 1 epoch, 0 batch" 1 0 128 150 50
                       (1.0 :: Float)
  ]
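
-- This module only exports `testTrees`; the actual test driver lives
-- elsewhere in the package. A hypothetical standalone runner, for
-- illustration only, would look like:
main :: IO ()
main = defaultMain $ testGroup "MNIST RNN ranked tests" testTrees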