diff --git a/boba/bobarun.py b/boba/bobarun.py index e569625..25da29f 100644 --- a/boba/bobarun.py +++ b/boba/bobarun.py @@ -1,5 +1,7 @@ import subprocess import os +import ast +import click from .lang import Lang from .wrangler import DIR_SCRIPT, DIR_LOG, get_universe_id_from_script, get_universe_log, get_universe_error_log, get_universe_name from subprocess import PIPE @@ -52,3 +54,12 @@ def run_commands_in_folder(folder, file_with_commands): for line in f.readlines(): os.system(line) os.chdir(cwd) + +class PythonDict(click.Option): + + def type_cast_value(self, ctx, value): + try: + return dict(ast.literal_eval(value)) + except: + raise click.BadParameter(value) + diff --git a/boba/cli.py b/boba/cli.py index c78a0c4..b08eeba 100644 --- a/boba/cli.py +++ b/boba/cli.py @@ -18,14 +18,19 @@ default='.', show_default=True) @click.option('--lang', help='Language, can be python/R [default: inferred from file extension]', default='') -def compile(script, out, lang): +@click.option('--decisions', '-d', cls=PythonDict, default='{}') +def compile(script, out, lang, decisions): """Generate multiverse analysis from specifications.""" check_path(script) click.echo('Creating multiverse from {}'.format(script)) ps = Parser(script, out, lang) - ps.main() + + if decisions: + ps.main(decisions) + else: + ps.main() ex = """To execute the multiverse, run the following commands: boba run --all diff --git a/boba/decisionparser.py b/boba/decisionparser.py index 474c245..2b8a351 100644 --- a/boba/decisionparser.py +++ b/boba/decisionparser.py @@ -12,6 +12,22 @@ class Decision: value: list desc: str = '' + def check_value(self, val): + """ Check if a value can be contained by this decision """ + for option in self.value: + if val == option: + return True + elif isinstance(option, dict) and isinstance(val, float): + if 'range' in option: + if val > option['range'][0] and val < option['range'][1]: + return True + if (('exclusive' in option and not option['exclusive']) or 'exclusive' not in option) and\ + (val == option['range'][0] or val == option['range'][1]): + return True + else: + return True + return False + class SamplingError(SyntaxError): pass @@ -47,7 +63,10 @@ def _is_id_token(s): @staticmethod def check_var_types(args, types, names): - for i in range(0, len(args)): + if len(args) != len(names): + raise ValueError('expected ' + str(len(names)) + ' arguments') + + for i in range(0, len(names)): if not isinstance(args[i], types[i]): raise ValueError(names[i] + ' must be of type ' + str(types[i])) @@ -66,14 +85,14 @@ def get_within_range(function, args, distribution_range, exclusive): return val @staticmethod - def random_uniform(minimum, maximum, args): + def random_uniform(distr_range, args): """randomly sample a number from a uniform distribution""" exclusive = args.get('exclusive', False) - DecisionParser.check_var_types([minimum, maximum, exclusive], - [float, float, bool], - ['min', 'max', 'exclusive']) + DecisionParser.check_var_types([distr_range, exclusive], + [list, bool], + ['range', 'exclusive']) - distr_range = [minimum, maximum] + DecisionParser.check_var_types(distr_range, [float, float], ['range[0]', 'range[1]']) return DecisionParser.get_within_range(random.uniform, distr_range, distr_range, exclusive) @staticmethod @@ -82,21 +101,17 @@ def rand_x_normal(function, args): mean = args.get('mean', 0.0) std_dev = args.get('std_dev', 1.0) exclusive = args.get('exclusive', False) - distribution_range = args.get('range', []) - DecisionParser.check_var_types([mean, std_dev, exclusive, distribution_range], + distr_range = args.get('range', []) + DecisionParser.check_var_types([mean, std_dev, exclusive, distr_range], [float, float, bool, list], ['mean', 'std_dev', 'exclusive', 'range']) - if len(distribution_range) == 0: - distribution_range = None - - if distribution_range and len(distribution_range) != 2: - raise ValueError('expected two items in range list') - elif distribution_range and len(distribution_range) == 2: - DecisionParser.check_var_types(distribution_range, [float, float], ['range[0]', 'range[1]']) - - if distribution_range: - return DecisionParser.get_within_range(function, [mean, std_dev], distribution_range, exclusive) + if len(distr_range) == 0: + distr_range = None + + if distr_range: + DecisionParser.check_var_types(distr_range, [float, float], ['range[0]', 'range[1]']) + return DecisionParser.get_within_range(function, [mean, std_dev], distr_range, exclusive) else: return function(mean, std_dev) @@ -114,7 +129,7 @@ def random_normal(args): def discretize(obj, discretization_method, count): """discretizes a continuous variable into 'count' descrete options.""" discretization_methods = { - 'uniform': DiscretizationFn(DecisionParser.random_uniform, ['min', 'max'], ['exclusive']), + 'uniform': DiscretizationFn(DecisionParser.random_uniform, ['range'], ['exclusive']), 'lognormal': DiscretizationFn(DecisionParser.random_lognormal, [], ['mean', 'std_dev', 'exclusive', 'range']), 'normal' : DiscretizationFn(DecisionParser.random_normal, [], ['mean', 'std_dev', 'exclusive', 'range']) } @@ -264,7 +279,7 @@ def get_decs(self): """Get a list of decision names.""" return [i for i in self.decisions.keys()] - def gen_code(self, template, dec_id, i_alt): + def gen_code(self, template, dec_id=None, i_alt=None, option=None): """ Replace the placeholder variable in a template chunk. :param template: a chunk of code with only one placeholder @@ -272,7 +287,12 @@ def gen_code(self, template, dec_id, i_alt): :param i_alt: which alternative :return: {string, string} replaced code and the value at this parameter """ - v = self.get_alt_discrete(dec_id, i_alt) + if dec_id is not None and i_alt is not None: + v = self.get_alt_discrete(dec_id, i_alt) + elif option is not None: + v = option + else: + raise ValueError("either dec_id and i_alt must not be None, or option must not be None") # assuming the placeholder var is always at the end # which is true given how we chop up the chunks diff --git a/boba/parser.py b/boba/parser.py index 57dfa1d..7c1f003 100644 --- a/boba/parser.py +++ b/boba/parser.py @@ -249,14 +249,14 @@ def _code_gen_recur(self, path, i, code, history): if chunk.variable != '': # check if we have already encountered the placeholder variable - prev_idx = None + prev_option = None for d in history.decisions: if d.parameter == chunk.variable: - prev_idx = d.idx + prev_option = d.option - if prev_idx is not None: + if prev_option is not None: # use the previous value - snippet, opt = self.dec_parser.gen_code(chunk.code, chunk.variable, prev_idx) + snippet, opt = self.dec_parser.gen_code(chunk.code, option=prev_option) self._code_gen_recur(path, i+1, code+snippet, history) else: # expand the decision @@ -272,7 +272,7 @@ def _code_gen_recur(self, path, i, code, history): continue # code gen - snippet, opt = self.dec_parser.gen_code(chunk.code, chunk.variable, k) + snippet, opt = self.dec_parser.gen_code(chunk.code, dec_id=chunk.variable, i_alt=k) decs = [a for a in history.decisions] decs.append(DecRecord(chunk.variable, opt, k)) self._code_gen_recur(path, i+1, code + snippet, @@ -281,7 +281,7 @@ def _code_gen_recur(self, path, i, code, history): code += chunk.code self._code_gen_recur(path, i+1, code, history) - def _code_gen(self): + def _code_gen(self, decrecords=None, path_filter=None): paths = self._get_code_paths() self.wrangler.counter = 0 # keep track of file name @@ -290,8 +290,46 @@ def _code_gen(self): + len(self.code_parser.get_decisions()) self.wrangler.create_dir() - for idx, p in enumerate(paths): - self._code_gen_recur(p, 0, '', History(idx)) + + if decrecords is None: + for idx, p in enumerate(paths): + self._code_gen_recur(p, 0, '', History(idx)) + elif path_filter is None: + for idx, p in enumerate(paths): + self._code_gen_recur(p, 0, '', History(idx, '', decrecords, [])) + else: + valid_paths = [] + for idx, p in enumerate(paths): + run_path = True + filtered_nodes = set() + for node in p: + if ':' in node[0]: + split = node[0].split(':') + name = split[0] + value = split[1] + + if name not in path_filter: + continue + + if path_filter[name] != value: + run_path = False + break + else: + filtered_nodes.add(name) + + for node in path_filter: + if node not in filtered_nodes: + run_path = False + break + + if run_path: + valid_paths.append((idx, p)) + + if not valid_paths: + raise ValueError('bad input to path_filter!') + + for path in valid_paths: + self._code_gen_recur(path[1], 0, '', History(path[0], '', decrecords, [])) # write the pre and post execs to a file. self.wrangler.write_pre_exe() @@ -387,9 +425,37 @@ def _warn_size(self): print('Aborted.') exit(0) - def main(self, verbose=True): + def main(self, decisions=None, verbose=True): self._warn_size() - self._code_gen() + + if decisions is not None: + decrecords = [] + code_path = 'path_filter' + if code_path in decisions: + path_filter = decisions[code_path] + else: + path_filter = None + + for key, value in decisions.items(): + if key == code_path: + continue + + dec_exist = False + for decname, dec in self.dec_parser.decisions.items(): + if decname == key: + if not dec.check_value(value): + raise ValueError('The decision "' + key + '" cannot have the value "' + str(value) + '"') + dec_exist = True + + if not dec_exist: + raise ValueError('The decision "' + key + '" does not exist') + + decrecords.append(DecRecord(key, value, 0)) + + self._code_gen(decrecords, path_filter) + else: + self._code_gen() + self._write_csv() self._write_server_config() if verbose: diff --git a/example/simple_cont/template.py b/example/simple_cont/template.py index 8ee3217..c085cc3 100644 --- a/example/simple_cont/template.py +++ b/example/simple_cont/template.py @@ -7,8 +7,7 @@ "seed" : 0, "sample" : "uniform", "count" : 50, - "min" : 1.0, - "max" : 3.0 + "range" : [1.0, 3.0] } ] } diff --git a/test/specs/continuous-err.json b/test/specs/continuous-err.json index 56075ed..bac1fe5 100644 --- a/test/specs/continuous-err.json +++ b/test/specs/continuous-err.json @@ -21,55 +21,45 @@ }, "4" : { "decisions": [ - {"var": "err", "options": [{"sample" : "uniform", "count" : 5, "min" : 0.0}] , "desc" : "check required variable omission"} + {"var": "err", "options": [{"sample" : "uniform", "count" : 5, "range" : [true, 1.0]}] , "desc" : "check bad type for variables"} ] }, "5" : { "decisions": [ - {"var": "err", "options": [{"sample" : "uniform", "count" : 5, "max" : 0.0}] , "desc" : "check required variable omission"} + {"var": "err", "options": [{"sample" : "uniform", "count" : 5, "range" : [1.0, "bad"]}] , "desc" : "check bad type for variables"} ] }, "6" : { - "decisions": [ - {"var": "err", "options": [{"sample" : "uniform", "count" : 5, "min" : true, "max" : 5.0}] , "desc" : "check bad type for variables"} - ] - }, - "7" : { - "decisions": [ - {"var": "err", "options": [{"sample" : "uniform", "count" : 5, "min" : 1.0, "max" : true}] , "desc" : "check bad type for variables"} - ] - }, - "8" : { "decisions": [ {"var": "err", "options": [{"sample" : "lognormal", "count" : 5, "exclusive" : 1.0}] , "desc" : "check bad type for variables"} ] }, - "9" : { + "7" : { "decisions": [ {"var": "err", "options": [{"sample" : "lognormal", "count" : 5, "mean" : "mean"}] , "desc" : "check bad type for variables"} ] }, - "10" : { + "8" : { "decisions": [ {"var": "err", "options": [{"sample" : "normal", "count" : 5, "range" : "range"}] , "desc" : "check bad type for variables"} ] }, - "11" : { + "9" : { "decisions": [ {"var": "err", "options": [{"sample" : "normal", "count" : 5, "range" : ["range", "range"]}] , "desc" : "check bad type for variables"} ] }, - "12" : { + "10" : { "decisions": [ {"var": "err", "options": [{"sample" : "normal", "count" : 5, "range" : [0.0, 1.0, 2.0]}] , "desc" : "check bad type for variables"} ] }, - "13" : { + "11" : { "decisions": [ {"var": "err", "options": [{"sample" : "normal", "count" : 5, "range" : [1.0, 0.0]}] , "desc" : "check bad type for variables"} ] }, - "14" : { + "12" : { "decisions": [ {"var": "err", "options": [{"sample" : "normal", "count" : 5, "std_dev" : 1.0, "range" : [1.0, 0.0]}] , "desc" : "check bad type for variables"} ] diff --git a/test/specs/continuous.json b/test/specs/continuous.json index 356825f..9b80d07 100644 --- a/test/specs/continuous.json +++ b/test/specs/continuous.json @@ -1,6 +1,6 @@ { "decisions": [ - {"var": "A", "options": [{"sample" : "uniform", "count" : 10, "seed" : 0, "min" : 0.0, "max" : 5.0}] , + {"var": "A", "options": [{"sample" : "uniform", "count" : 10, "seed" : 0,"range" : [0.0, 5.0]}] , "desc" : "uniform continuous variable expansion"}, {"var": "B", "options": [{"sample" : "lognormal", "count" : 10, "seed" : 0, "mean" : 0.0, "std_dev" : 5.0}] , @@ -9,7 +9,7 @@ {"var": "C", "options": [{"sample" : "normal", "count" : 10, "seed" : 0, "mean" : 0.0, "std_dev" : 5.0}] , "desc" : "normal continuous variable expansion"}, - {"var": "D", "options": [{"sample" : "uniform", "count" : 10, "seed" : 0, "min" : 0.0, "max" : 5.0}, 17.0] , + {"var": "D", "options": [{"sample" : "uniform", "count" : 10, "seed" : 0,"range" : [0.0, 5.0]}, 17.0] , "desc" : "uniform continuous variable expansion with additional constants"}, {"var": "E", "options": [{"sample" : "lognormal", "count" : 10, "seed" : 0, "mean" : 0.0, "std_dev" : 5.0}, 0.0, 1.0, 2.0] , @@ -18,17 +18,17 @@ {"var": "F", "options": [{"sample" : "normal", "count" : 10, "seed" : 0, "mean" : 0.0, "std_dev" : 5.0}, 0.0, 1.0, 2.0, 3.0, 4.0] , "desc" : "normal continuous variable expansion with additional constants"}, - {"var": "G", "options": [{"sample" : "uniform", "count" : 3, "seed" : 0, "min" : 0.0, "max" : 5.0}, + {"var": "G", "options": [{"sample" : "uniform", "count" : 3, "seed" : 0,"range" : [0.0, 5.0]}, {"sample" : "lognormal", "count" : 3, "seed" : 0, "mean" : 0.0, "std_dev" : 5.0}, {"sample" : "normal", "count" : 3, "seed" : 0, "mean" : 0.0, "std_dev" : 5.0}] , "desc" : "multiple continuous variable expansions"}, - {"var": "H", "options": [{"sample" : "uniform", "count" : 4, "seed" : 0, "min": 0.0, "max": 5.0}, - {"sample" : "uniform", "count" : 4, "seed" : 1, "min": 10.0, "max": 15.0}] , + {"var": "H", "options": [{"sample" : "uniform", "count" : 4, "seed" : 0, "range" : [0.0, 5.0]}, + {"sample" : "uniform", "count" : 4, "seed" : 1, "range" : [10.0, 15.0]}] , "desc" : "multiple continuous variable expansions"}, - {"var": "I", "options": [{"sample" : "uniform", "count" : 3, "seed" : 0, "min": 0.0, "max": 5.0}, -1.1, - {"sample" : "uniform", "count" : 3, "seed" : 1, "min": 10.0, "max": 15.0}, + {"var": "I", "options": [{"sample" : "uniform", "count" : 3, "seed" : 0, "range" : [0.0, 5.0]}, -1.1, + {"sample" : "uniform", "count" : 3, "seed" : 1, "range" : [10.0, 15.0]}, 0.0, 1.0, 2.0, 3.1415] , "desc" : "multiple continuous variable expansions with additional constants"}, diff --git a/test/test_decision_parser.py b/test/test_decision_parser.py index e4f3da0..89419c1 100644 --- a/test/test_decision_parser.py +++ b/test/test_decision_parser.py @@ -146,17 +146,15 @@ def test_continuous_err(self): "1" : ParseError, "2" : ParseError, "3" : DiscretizationError, - "4" : DiscretizationError, - "5" : DiscretizationError, + "4" : ValueError, + "5" : ValueError, "6" : ValueError, "7" : ValueError, "8" : ValueError, "9" : ValueError, "10" : ValueError, "11" : ValueError, - "12" : ValueError, - "13" : ValueError, - "14" : ValueError + "12" : ValueError } for name, error in expected_errs.items(): diff --git a/test/test_parser.py b/test/test_parser.py index 4ef99c3..cbc271b 100644 --- a/test/test_parser.py +++ b/test/test_parser.py @@ -6,6 +6,7 @@ sys.path.insert(0, os.path.abspath('..')) import unittest +import pandas from unittest.mock import patch import io from boba.parser import Parser @@ -25,6 +26,32 @@ def abs_path(rel_path): return os.path.join(os.path.dirname(__file__), rel_path) +def verify_summary_csv(summary, expected_values): + data = pandas.read_csv(summary) + for key, val in expected_values.items(): + if key not in data: + return False + + values = data[key].to_list() + if not all(elem == val for elem in values): + return False + + return True + + +class FilterTestCase: + def __init__(self, folder, script): + self.folder = folder + self.script = script + self.cases = [] + + def add_case(self, filt, expected=None): + if expected is None: + self.cases.append((filt, filt)) + else: + self.cases.append((filt, expected)) + + class TestParser(unittest.TestCase): # --- code gen --- @@ -257,6 +284,79 @@ def test_spec_cyclic_graph(self, stdout): Parser(base + 'script1-cyclic-graph.py') self.assertRegex(stdout.getvalue(), 'Cannot find any starting node') + def test_decision_filter(self): + base = abs_path('../example') + + filts = [] + + filt = FilterTestCase('simple', 'template.py') + filt.add_case({'path_filter' : {'A' : 'iqr'}}, {'A' : 'iqr'}) + filt.add_case({'cutoff' : 2.5}) + filt.add_case({'cutoff' : 3}) + filt.add_case({'path_filter' : {'A' : 'std'}, 'cutoff' : 2}, {'A' : 'std', 'cutoff' : 2}) + filts.append(filt) + + filt = FilterTestCase('simple_cont', 'template.py') + filt.add_case({'path_filter' : {'A' : 'iqr'}}, {'A' : 'iqr'}) + filt.add_case({'cutoff' : 1.0}) + filt.add_case({'cutoff' : 1.5}) + filt.add_case({'cutoff' : 2.0}) + filt.add_case({'cutoff' : 2.5}) + filt.add_case({'cutoff' : 3.0}) + filt.add_case({'path_filter' : {'A' : 'std'}, 'cutoff' : 2.25}, {'A' : 'std', 'cutoff' : 2.25}) + filt.add_case({'path_filter' : {'A' : 'iqr'}, 'cutoff' : 1.0}, {'A' : 'iqr', 'cutoff' : 1.0}) + filt.add_case({'path_filter' : {'A' : 'iqr'}, 'cutoff' : 3.0}, {'A' : 'iqr', 'cutoff' : 3.0}) + filts.append(filt) + + for filt in filts: + ps = Parser(base + '/' + filt.folder + '/' + filt.script, base + '/' + filt.folder) + for i in range(0, len(filt.cases)): + f = filt.cases[i] + ps.main(f[0], verbose=False) + csv = base + '/' + filt.folder + '/multiverse/summary.csv' + if not verify_summary_csv(csv, f[1]): + msg = 'failed on test case ' + str(i) + ' of ' + filt.folder + '\n' + msg += 'filter was: ' + str(f[0]) + ', expected ' + str(f[1]) + '\n' + self.fail(msg) + + def test_decision_filter_err(self): + base = abs_path('../example') + + filts = [] + + filt = FilterTestCase('simple', 'template.py') + filt.add_case({'path_filter' : {'C' : 'iqr'}}, ValueError) + filt.add_case({'path_filter' : {'A' : 'bad'}}, ValueError) + filt.add_case({'cutoff' : 'bad data'}, ValueError) + filt.add_case({'cutoff' : 1}, ValueError) + filts.append(filt) + + filt = FilterTestCase('simple_cont', 'template.py') + filt.add_case({'path_filter' : {'C' : 'iqr'}}, ValueError) + filt.add_case({'path_filter' : {'A' : 'bad'}}, ValueError) + filt.add_case({'cutoff' : 'bad data'}, ValueError) + filt.add_case({'cutoff' : 5.0}, ValueError) + filts.append(filt) + + for filt in filts: + ps = Parser(base + '/' + filt.folder + '/' + filt.script, base + '/' + filt.folder) + for i in range(0, len(filt.cases)): + f = filt.cases[i] + thrown = False + msg = 'failed on test case ' + str(i) + ' of ' + filt.folder + '\n' + try: + ps.main(f[0], verbose=False) + except Exception as e: + msg += 'filter was: ' + str(f[0]) + ', expected ' + str(f[1]) + ', recieved ' + str(e) + '\n' + if not isinstance(e, f[1]): + self.fail(msg) + + thrown = True + + if not thrown: + msg += 'no exception was thrown' + self.fail(msg) + if __name__ == '__main__': unittest.main()