module Statistics.Regression
(
ols
, rSquare
) where
import Control.Applicative ((<$>))
import Prelude hiding (sum)
import Statistics.Function as F
import Statistics.Matrix
import Statistics.Matrix.Algorithms (qr)
import Statistics.Sample (mean)
import Statistics.Sample.Internal (sum)
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as M
-- | Fit an ordinary least-squares linear model: find coefficients /x/ that
-- best satisfy /A x = b/ in the least-squares sense.
--
-- The predictor matrix is factored with 'qr' into /Q R/, and the
-- coefficients are recovered by back-substituting /R x = Q^T b/ with
-- 'solve'. This assumes 'qr' yields a /Q/ with orthonormal columns and an
-- upper-triangular /R/.
--
-- Calls 'error' when /A/ has fewer rows than columns, because the system
-- is then underdetermined.
ols :: Matrix  -- ^ /A/: predictor matrix; needs at least as many rows as columns.
    -> Vector  -- ^ /b/: responses; presumably one per row of /A/ (it is multiplied by /Q^T/).
    -> Vector  -- ^ Fitted coefficients /x/.
ols a b
  | rs >= cs  = solve r qtb
  | otherwise = error $ "fewer rows than columns " ++ show d
  where
    d@(rs, cs) = dimension a
    (q, r)     = qr a
    qtb        = transpose q `multiplyV` b
-- | Solve /R x = b/ for /x/ by back substitution.
--
-- /R/ is treated as upper triangular: only its diagonal and the entries
-- above the diagonal are read. Rows are processed from the last to the
-- first, so each /x_i/ is final before its contribution is removed from
-- the earlier rows.
--
-- Calls 'error' if the number of rows of /R/ differs from the length of
-- /b/ (the column count is not checked). A zero on the diagonal is not
-- detected; the division then yields an infinite or NaN component.
solve :: Matrix  -- ^ /R/: upper-triangular, presumably square (n x n).
      -> Vector  -- ^ /b/: right-hand side, of length n.
      -> Vector  -- ^ Solution /x/.
solve r b
  | n /= l    = error $ "row/vector mismatch " ++ show (n,l)
  | otherwise = U.create $ do
      x <- U.thaw b
      -- Finalise x_i, then subtract R[j][i] * x_i from every row j above it.
      let eliminate i = do
            raw <- M.unsafeRead x i
            let xi = raw / unsafeIndex r i i
            M.unsafeWrite x i xi
            for 0 i $ \j -> F.unsafeModify x j (subtract (unsafeIndex r j i * xi))
      rfor n 0 eliminate
      return x
  where
    n = rows r
    l = U.length b
-- | Coefficient of determination (R^2) of a fitted linear model:
--
-- > R^2 = 1 - SS_res / SS_tot
--
-- where @SS_res = sum_i (y_i - yhat_i)^2@ is the residual sum of squares,
-- with fitted values @yhat_i = sum_j A_ij * beta_j@, and
-- @SS_tot = sum_i (y_i - ybar)^2@ is the total sum of squares about the
-- mean response @ybar@.
--
-- Calls 'error' if the predictor matrix does not have exactly one row per
-- response and one column per coefficient. When every response equals the
-- mean, @SS_tot@ is zero and the result is a non-finite IEEE value (NaN or
-- -Infinity).
rSquare :: Matrix  -- ^ Predictors (regressors), one row per observation.
        -> Vector  -- ^ Responses, one per row of the predictor matrix.
        -> Vector  -- ^ Regression coefficients, one per predictor column.
        -> Double
rSquare predictors resp coeff
  -- Without this check, 'unsafeIndex' would silently read the wrong cells.
  | dims /= expected =
      error $ "predictor/response/coefficient mismatch " ++ show (dims, expected)
  | otherwise = 1 - ssRes / ssTot
  where
    dims     = dimension predictors
    expected = (U.length resp, U.length coeff)
    -- Hoisted so the mean is computed once rather than once per element.
    yBar     = mean resp
    ssRes    = sum $ flip U.imap resp $ \i y -> square (y - fitted i)
    ssTot    = sum $ U.map (\y -> square (y - yBar)) resp
    -- Fitted value for observation i: row i of the predictors dotted with
    -- the coefficient vector.
    fitted i = sum $ U.imap (\j c -> c * unsafeIndex predictors i j) coeff