Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
AugustUnderground committed Sep 15, 2024
1 parent 385254c commit 534a686
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
8 changes: 4 additions & 4 deletions src/Run.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import HyperParameters
import Data.List (elemIndex)
import Data.Maybe (fromJust)
import Control.Monad (when)
import Control.Monad.State (gets, MonadIO (..), MonadState (..), StateT (..), evalStateT)
import Control.Monad.State (gets, MonadIO (..), MonadState (..), StateT (..))
import Torch ( Tensor, Dim (..), LearningRate
, KeepDim (..), Reduction (..))
import qualified Torch as T
Expand Down Expand Up @@ -61,7 +61,7 @@ validEpoch [] = do
validEpoch (b:bs) = do
s@TrainState{..} <- get
let l = validStep model b
put $ s {lossBuffer' = l : lossBuffer'}
put s {lossBuffer' = l : lossBuffer'}
validEpoch bs

-- | Training Step with Gradient
Expand Down Expand Up @@ -166,7 +166,7 @@ train num = do
let predict = scale' minY maxY . forward net' . scale minX maxX

traceModel dimX paramsX paramsY predict >>= saveInferenceModel modelDir
trace <- loadInferenceModel modelDir >>= noGrad . unTraceModel
net'' <- loadInferenceModel modelDir >>= noGrad . unTraceModel

testModel paramsY net'' datX' datY'

Expand All @@ -184,6 +184,6 @@ testModel :: [String] -> (T.Tensor -> T.Tensor) -> Tensor -> Tensor -> IO ()
testModel paramsY net xs ys = do
let ys' = net xs
mape = T.asValue @[Float] . T.meanDim (Dim 0) RemoveDim T.Float
. T.mulScalar 100 . T.abs $ (ys - ys') / ys
. T.mulScalar @Float 100.0 . T.abs $ (ys - ys') / ys
putStrLn "Prediction MAPEs"
mapM_ putStrLn $ zipWith (\p m -> p ++ ":\t" ++ show m ++ "%") paramsY mape
3 changes: 0 additions & 3 deletions stack.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
resolver: lts-22.28
compiler: ghc-9.6.5

system-ghc: true
install-ghc: false

ghc-options:
"$locals": -funfolding-use-threshold=16 -fexcess-precision -optc-O3 -optc-ffast-math -threaded -O2 -rtsopts -with-rtsopts=-N +RTS -xp -RTS # -v

Expand Down

0 comments on commit 534a686

Please sign in to comment.