-- |
-- Module      : ConClusion.Array.Util
-- Description : Additional tools to work with numerical arrays
-- Copyright   : Phillip Seeber, 2022
-- License     : AGPL-3
-- Maintainer  : phillip.seeber@googlemail.com
-- Stability   : experimental
-- Portability : POSIX, Windows
module ConClusion.Array.Util
  ( IndexException (..),
    magnitude,
    normalise,
    angle,
    minDistAt,
    minDistAtVec,
    iMinimumM,
  )
where

import Data.Massiv.Array as Massiv hiding (IndexException)
import RIO
import System.IO.Unsafe (unsafePerformIO)

-- | Exception signalling a problem with indexing into some kind of array.
-- The wrapped 'String' is rendered by the derived 'Show' instance.
newtype IndexException = IndexException String
  deriving Show

-- | Uses the default 'Exception' methods, so the exception can be thrown
-- with 'throwM' and caught by its type.
instance Exception IndexException

-- | Euclidean length of a vector: the square root of the vector's dot
-- product with itself.
magnitude :: (Massiv.Numeric r e, Source r e, Floating e) => Massiv.Vector r e -> e
magnitude vec =
  let squaredLength = vec !.! vec
   in sqrt squaredLength

-- | Scale a vector to unit length by multiplying every component with the
-- reciprocal of its 'magnitude'.
--
-- NOTE(review): a zero vector leads to a division by zero here; for IEEE
-- floating point element types the components become non-finite.
normalise :: (Massiv.Numeric r e, Source r e, Floating e) => Massiv.Vector r e -> Massiv.Vector r e
normalise v = v .* invLength
  where
    invLength = 1 / magnitude v

-- | Angle between two vectors, obtained by applying 'acos' to the dot product
-- divided by the product of both 'magnitude's.
--
-- NOTE(review): floating point rounding can push the quotient marginally
-- outside [-1, 1] (e.g. for (anti)parallel vectors such as @angle v v@), in
-- which case 'acos' yields NaN for IEEE types. Clamping would require an
-- additional 'Ord' constraint on the element type.
angle :: (Massiv.Numeric r e, Source r e, Floating e) => Massiv.Vector r e -> Massiv.Vector r e -> e
angle a b = acos cosine
  where
    cosine = (a !.! b) / (magnitude a * magnitude b)

-- | Find the minimal distance in a distance matrix, considering only elements
-- strictly below the main diagonal (row index greater than column index). For
-- a symmetric distance matrix this visits every pair exactly once. Returns
-- the minimal value together with its 'Ix2' index.
--
-- The search is a parallel fold ('ifoldlP'); within a chunk and when merging
-- chunk results, a candidate only replaces the current one if it is strictly
-- smaller.
--
-- Throws 'SizeEmptyException' for an empty matrix and
-- 'IndexOutOfBoundsException' if the matrix has fewer than two rows (and thus
-- no element below the main diagonal).
{-# SCC minDistAt #-}
minDistAt ::
  ( Manifest r e,
    MonadThrow m,
    Ord e
  ) =>
  Massiv.Matrix r e ->
  m (e, Ix2)
minDistAt arr
  | isEmpty arr = throwM $ SizeEmptyException (Massiv.size arr)
  | nRows < 2 = throwM $ IndexOutOfBoundsException (Massiv.size arr) ix0
  | otherwise = return . unsafePerformIO $ ifoldlP minFold start chFold start arr
  where
    Sz (nRows :. _) = Massiv.size arr
    -- Seed with the first element strictly below the diagonal. Seeding with a
    -- diagonal element (e.g. a zero self-distance) would let it win against
    -- every off-diagonal entry, because 'minFold' never replaces the seed
    -- with a diagonal element but compares all others against it.
    ix0 :: Ix2
    ix0 = 1 :. 0
    start = (arr Massiv.! ix0, ix0)
    -- Per-element step: only strictly lower triangular entries (m > n) count.
    minFold acc@(eA, _) ix@(m :. n) e = if e < eA && m > n then (e, ix) else acc
    -- Merge of chunk results: keep the earlier one unless strictly larger.
    chFold acc@(eA, _) ch@(e, _) = if e < eA then ch else acc

-- | Find the minimal element of a vector among all elements whose index is
-- strictly larger than the supplied index. Returns the minimal value together
-- with its index in the original vector. Among equal minimal values the
-- smallest index is returned. A negative @ixStart@ searches the whole vector.
--
-- Throws 'SizeEmptyException' for an empty vector and
-- 'IndexOutOfBoundsException' if no element lies after @ixStart@ (i.e.
-- @ixStart >= length - 1@).
minDistAtVec ::
  ( Manifest r e,
    MonadThrow m,
    Ord e
  ) =>
  Ix1 ->
  Massiv.Vector r e ->
  m (e, Ix1)
minDistAtVec ixStart vec
  | isEmpty vec = throwM $ SizeEmptyException (Massiv.size vec)
  -- The last index has no successor: the search range would be empty and the
  -- seed lookup below would fail lazily inside an apparently successful result.
  | ixStart >= nElems - 1 = throwM $ IndexOutOfBoundsException (Sz nElems) (ixStart + 1)
  | otherwise = do
      let (minE, minIx) = unsafePerformIO $ ifoldlP minFold startAcc chFold startAcc searchVec
      return (minE, minIx + offset)
  where
    Sz nElems = Massiv.size vec
    -- Number of leading elements skipped. Clamped at 0 so that a negative
    -- 'ixStart' does not shift the returned index ('Sz' clamps the drop
    -- count to 0 as well).
    offset = max 0 (ixStart + 1)
    searchVec = Massiv.drop (Sz offset) vec
    startAcc = (searchVec Massiv.! 0, 0)
    -- Within a chunk, keep the earliest of equal minima.
    minFold acc@(eA, _) ix e = if e < eA then (e, ix) else acc
    -- Tuple ordering compares values first, then indices, so ties across
    -- chunks resolve to the smaller index.
    chFold acc ch = min acc ch

-- | Like 'Massiv.minimumM', but additionally returns the index at which the
-- minimal element was found.
--
-- The array is folded in parallel with 'ifoldlP', seeded with the element at
-- the zero index. Both within a chunk and when merging chunk results, a
-- candidate only replaces the current best if it is strictly smaller.
--
-- Throws 'SizeEmptyException' if the array has no elements.
iMinimumM ::
  ( Manifest r a,
    MonadThrow m,
    Index ix,
    Ord a
  ) =>
  Array r ix a ->
  m (a, ix)
iMinimumM arr
  | isEmpty arr = throwM (SizeEmptyException (Massiv.size arr))
  | otherwise = pure (unsafePerformIO (ifoldlP keepSmaller seed pickSmaller seed arr))
  where
    origin = pureIndex 0
    seed = (arr Massiv.! origin, origin)
    -- Per-element step of the fold.
    keepSmaller best@(bestVal, _) ix val
      | val < bestVal = (val, ix)
      | otherwise = best
    -- Combination of two chunk results.
    pickSmaller best@(bestVal, _) candidate@(candVal, _)
      | candVal < bestVal = candidate
      | otherwise = best