Skip to content

Commit

Permalink
Test the sampling methods (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
dhadka authored Mar 24, 2020
1 parent 1c98d74 commit 031396a
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
2 changes: 2 additions & 0 deletions rhodium/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def sample_uniform(model, nsamples):

result.append(entry)

return result

def sample_lhs(model, nsamples):
"""Returns a data set with uncertainty parameters sampled using Latin hypercube sampling."""
if len(model.uncertainties) == 0:
Expand Down
48 changes: 48 additions & 0 deletions rhodium/test/sampling_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2015-2016 David Hadka
#
# This file is part of Rhodium, a Python module for robust decision making and
# exploratory modeling.
#
# Rhodium is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Rhodium is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Rhodium. If not, see <http://www.gnu.org/licenses/>.
from __future__ import division, print_function, absolute_import

import six
import unittest
from rhodium import *

class TestSampling(unittest.TestCase):

def testUniform(self):
model = Model("foo")
model.uncertainties = [UniformUncertainty("x", 5.0, 10.0)]

samples = sample_uniform(model, 100)

self.assertEquals(100, len(samples))

for i in range(len(samples)):
self.assertTrue("x" in samples[i])
self.assertTrue(samples[i]["x"] >= 5.0 and samples[i]["x"] <= 10.0)

def testLHS(self):
model = Model("foo")
model.uncertainties = [UniformUncertainty("x", 5.0, 10.0)]

samples = sample_lhs(model, 100)

self.assertEquals(100, len(samples))

for i in range(len(samples)):
self.assertTrue("x" in samples[i])
self.assertTrue(samples[i]["x"] >= 5.0 and samples[i]["x"] <= 10.0)

0 comments on commit 031396a

Please sign in to comment.