diff --git a/.github/workflows/continuous_integration.yml b/.github/workflows/continuous_integration.yml index ed0ab75f34..a688d5b959 100644 --- a/.github/workflows/continuous_integration.yml +++ b/.github/workflows/continuous_integration.yml @@ -25,29 +25,14 @@ jobs: with: python-version: ${{ matrix.python-version }} - - name: Install POSYDON without extras + - name: Install POSYDON with test dependencies run: | python -m pip install --upgrade pip - pip install . + pip install .[test] - name: Run all tests in posydon/unit_tests run: | - # python -m pip install --upgrade pip - # pip install . - pip install pytest - pip install pytest-cov export PATH_TO_POSYDON=./ export PATH_TO_POSYDON_DATA=./posydon/unit_tests/_data/ export MESA_DIR=./ - python -m pytest posydon/unit_tests/ \ - --cov=posydon.config \ - --cov=posydon.utils \ - --cov=posydon.grids \ - --cov=posydon.popsyn.IMFs \ - --cov=posydon.popsyn.norm_pop \ - --cov=posydon.popsyn.distributions \ - --cov=posydon.popsyn.star_formation_history \ - --cov=posydon.CLI \ - --cov-branch \ - --cov-report term-missing \ - --cov-fail-under=100 + pytest # run and coverage parameters are defined in pyproject.toml diff --git a/.gitmodules b/.gitmodules index 6cda6aff11..1cfccfceda 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,9 +1,3 @@ [submodule "grid_params/POSYDON-MESA-INLISTS"] path = grid_params/POSYDON-MESA-INLISTS url = https://github.com/POSYDON-code/POSYDON-MESA-INLISTS.git -[submodule "data/POSYDON_data"] - path = data/POSYDON_data - url = https://github.com/POSYDON-code/POSYDON_data.git -[submodule "posydon/tests/data/POSYDON-UNIT-TESTS"] - path = posydon/tests/data/POSYDON-UNIT-TESTS - url = https://github.com/POSYDON-code/POSYDON-UNIT-TESTS.git diff --git a/conda/meta.yaml b/conda/meta.yaml index eed55444e2..751d76aec0 100644 --- a/conda/meta.yaml +++ b/conda/meta.yaml @@ -1,9 +1,8 @@ {% set name = "posydon" %} -{% set version = "2.2.6" %} package: name: "{{ name|lower }}" - version: "{{ version }}" + version: {{ GIT_DESCRIBE_TAG }} source: path: .. @@ -17,7 +16,8 @@ requirements: host: - pip - python==3.11 - - setuptools>=38.2.5 + - setuptools>=76.0.0 + - setuptools-scm>=8.0 run: - python==3.11 diff --git a/data/POSYDON_data b/data/POSYDON_data deleted file mode 160000 index e5d8d77985..0000000000 --- a/data/POSYDON_data +++ /dev/null @@ -1 +0,0 @@ -Subproject commit e5d8d77985fc1502b6b6cc0400577623d50743ab diff --git a/posydon/__init__.py b/posydon/__init__.py index bf13af499c..d3c2c0325a 100644 --- a/posydon/__init__.py +++ b/posydon/__init__.py @@ -1,6 +1,11 @@ -from ._version import get_versions +from importlib.metadata import PackageNotFoundError, version + +try: + __version__ = version("posydon") +except PackageNotFoundError: + # Package is not installed + __version__ = "unknown" -__version__ = get_versions()['version'] __author__ = "Tassos Fragos " __credits__ = [ "Emmanouil Zapartas ", @@ -19,5 +24,3 @@ "Ying Qin <", "Aaron Dotter ", ] - -del get_versions diff --git a/posydon/_version.py b/posydon/_version.py deleted file mode 100644 index 030e6a8b1b..0000000000 --- a/posydon/_version.py +++ /dev/null @@ -1,532 +0,0 @@ - -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) - -"""Git implementation of _version.py.""" - -__authors__ = [ - "Scott Coughlin ", - "Matthias Kruckow ", -] - -import errno -import os -import re -import subprocess -import sys - - -def get_keywords(): - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). - git_refnames = "$Format:%d$" - git_full = "$Format:%H$" - git_date = "$Format:%ci$" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_config(): - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "pep440" - cfg.tag_prefix = "v" - cfg.parentdir_prefix = "" - cfg.versionfile_source = "posydon/_version.py" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Get decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - p = None - for c in commands: - try: - dispcmd = str([c] + args) - # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) - break - except EnvironmentError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %s" % (commands,)) - return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % stdout) - return None, p.returncode - return stdout, p.returncode - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for i in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") - date = keywords.get("date") - if date is not None: - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - if verbose: - print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%s*" % tag_prefix], - cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date, rc = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root) - if rc != 0: - if verbose: - print("Retry 'git show'") - date, rc = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root) - if date is None: - raise NotThisMethod("'git show' failed") - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] - else: - # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Eexceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -def get_versions(): - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. - - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. - for i in cfg.versionfile_source.split('/'): - root = os.path.dirname(root) - except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} diff --git a/posydon/tests/active_learning/psy_cris/test_Classifier.py b/posydon/tests/active_learning/psy_cris/test_Classifier.py deleted file mode 100644 index 7c7d7f2953..0000000000 --- a/posydon/tests/active_learning/psy_cris/test_Classifier.py +++ /dev/null @@ -1,159 +0,0 @@ -"""Unit test for posydon.active_learning.psy_cris classes -""" -import unittest - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd - -from posydon.active_learning.psy_cris.classify import Classifier -from posydon.active_learning.psy_cris.data import TableData -from posydon.active_learning.psy_cris.synthetic_data.synth_data_3D import get_output_3D -from posydon.active_learning.psy_cris.utils import ( - get_random_grid_df, - get_regular_grid_df, -) - -# True for faster runtime ~ 3s vs 15s -SKIP_GP_TESTS = True - -SKIP_TEST_PLOTS = False -SHOW_PLOTS = False - - -class TestClassifier(unittest.TestCase): - """Test Classifier class on the 3d synthetic data set.""" - - @classmethod - def setUpClass(cls): - np.random.seed(12345) - cls.TEST_DATA_GRID = get_regular_grid_df(N=10 ** 3, dim=3) - cls.TEST_DATA_RAND = get_random_grid_df(N=10 ** 3, dim=3) - cls.UNIQUE_CLASSES = [1, 2, 3, 4, 6, 8] - - cls.TEST_INPUT_POINTS = np.array([[0, 0, 0], [-0.5, 0.5, 0.5]]) - cls.TEST_OUTPUT_TRUTH = get_output_3D(*cls.TEST_INPUT_POINTS.T) - - def setUp(self): - my_kwargs = {"n_neighbors": [2, 3]} - self.table_grid = self.create_TableData(self.TEST_DATA_GRID, **my_kwargs) - self.table_rand = self.create_TableData(self.TEST_DATA_RAND, **my_kwargs) - self.cls_obj_grid = Classifier(self.table_grid) - self.cls_obj_rand = Classifier(self.table_rand) - - def create_TableData(self, data_frame, **kwargs): - files = None - input_cols = ["input_1", "input_2", "input_3"] - output_cols = ["class", "output_1"] - class_col_name = "class" - table_obj = TableData( - files, - input_cols, - output_cols, - class_col_name, - my_DataFrame=data_frame, - verbose=False, - **kwargs - ) - return table_obj - - def test_train_classifiers_1(self): - # di can be [16,19,24,32,41,48,64] to avoid multi class error for gp - data_range = np.arange(0, 1000)[::64] - for cls in ["rbf", "linear", "gp"]: - with self.subTest("Train grid:", classifier=cls): - self.cls_obj_grid.train(cls, di=data_range, verbose=False) - - def test_train_classifiers_2(self): - # skipping gp here - data_range = np.arange(0, 1000)[::10] - for cls in ["rbf", "linear"]: - with self.subTest("Train random:", classifier=cls): - self.cls_obj_rand.train(cls, di=data_range, verbose=False) - - def train_grid_classifiers(self, cls_names, **kwargs): - for name in cls_names: - self.cls_obj_grid.train(name, **kwargs) - - def test_predictions(self): - self.train_grid_classifiers(["linear", "rbf"]) - correct_probabilities = [0.98519722, 1.00000, 0.5883452] - for i, cls_name in enumerate(["rbf", "linear"]): - with self.subTest("Get class predictions:", classifier=cls_name): - tup_out = self.cls_obj_grid.get_class_predictions( - cls_name, self.TEST_INPUT_POINTS, return_ids=False - ) - class_pred, probs, where_not_nan = tup_out - - self.assertTrue( - (3 in class_pred) and (8 in class_pred), - msg="All predictions should contain class 3 and 8", - ) - self.assertAlmostEqual(probs[0], correct_probabilities[i], places=3) - self.assertTrue(len(where_not_nan) == 2, msg="Should not get any nans.") - - @unittest.skipIf(SKIP_GP_TESTS, "GP train / predict - long runtime.") - def test_predictions_gp(self): - self.train_grid_classifiers(["gp"], di=np.arange(0, 1000)[::6]) - tup_out = self.cls_obj_grid.get_class_predictions( - "gp", self.TEST_INPUT_POINTS, return_ids=False - ) - class_pred, probs, where_not_nan = tup_out - - self.assertTrue( - (3 in class_pred) and (8 in class_pred), - msg="All predictions should contain class 3 and 8", - ) - self.assertAlmostEqual(probs[0], 0.5883452, places=3) - self.assertTrue(len(where_not_nan) == 2, msg="Should not get any nans.") - - def test_pred_train_err(self): - # Trying to predict without training - names = ["grid", "random"] - for i, classifier in enumerate([self.cls_obj_grid, self.cls_obj_rand]): - with self.subTest(classifier_name=names[i]): - with self.assertRaisesRegex( - Exception, "No trained interpolators exist" - ): - classifier.get_class_predictions("linear", [[0, 0, 0]]) - - def test_pred_linear_err(self): - self.cls_obj_rand.train("linear", di=np.arange(0, 1000, 50)) - tup_out = self.cls_obj_rand.get_class_predictions( - "lin", [[-1, -1, -1], [1, 1, 1]], return_ids=False - ) - self.assertTrue(len(tup_out[2]) == 0, msg="Should return no valid values.") - - # def test_cross_val(self): - # correct_ans = [67.36842105263158, 66.66666666666666] - # acc, times = self.cls_obj_grid.cross_validate( - # ["rbf", "linear"], 0.05, verbose=False - # ) - # for i, percent_acc in enumerate(acc): - # with self.subTest("Cross Val", i=i, percent_acc=percent_acc): - # self.assertAlmostEqual(acc[i], correct_ans[i], places=3) - - @unittest.skipIf(SKIP_GP_TESTS, "GP cross_val - long runtime.") - def test_cross_val_gp(self): - correct_ans = [73.76470588235294] - acc, times = self.cls_obj_grid.cross_validate(["gp"], 0.15, verbose=False) - for i, percent_acc in enumerate(acc): - with self.subTest("Cross Val", i=i, percent_acc=percent_acc): - self.assertAlmostEqual(acc[i], correct_ans[i], places=3) - - @unittest.skipIf(SKIP_TEST_PLOTS, "Skipping maximum class P plot.") - def test_max_cls_plot(self): - N = int(2e4) if SHOW_PLOTS else 100 - self.train_grid_classifiers(["rbf"]) - fig, axes = self.cls_obj_grid.make_max_cls_plot( - "rbf", ("input_1", "input_2"), N=N, s=3, alpha=0.6, cmap="bone" - ) - if SHOW_PLOTS: - fig.show() - else: - print("To show plots set SHOW_PLOTS to True.") - plt.close(fig) - - -if __name__ == "__main__": - unittest.main() diff --git a/posydon/tests/active_learning/psy_cris/test_Regressor.py b/posydon/tests/active_learning/psy_cris/test_Regressor.py deleted file mode 100644 index 94695db620..0000000000 --- a/posydon/tests/active_learning/psy_cris/test_Regressor.py +++ /dev/null @@ -1,174 +0,0 @@ -"""Unit test for posydon.active_learning.psy_cris classes -""" -import math -import unittest - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd - -from posydon.active_learning.psy_cris.data import TableData -from posydon.active_learning.psy_cris.regress import Regressor -from posydon.active_learning.psy_cris.synthetic_data.synth_data_3D import get_output_3D -from posydon.active_learning.psy_cris.utils import ( - get_random_grid_df, - get_regular_grid_df, -) - -# True for faster runtime ~ 1s vs 3s -SKIP_GP_TESTS = False - -SKIP_TEST_PLOTS = False -SHOW_PLOTS = False - - -class TestRegressor(unittest.TestCase): - """Test Regressor class on the 3d synthetic data set.""" - - @classmethod - def setUpClass(cls): - np.random.seed(12345) - cls.TEST_DATA_GRID = get_regular_grid_df(N=10 ** 3, dim=3) - cls.TEST_DATA_RAND = get_random_grid_df(N=10 ** 3, dim=3) - cls.UNIQUE_CLASSES = [1, 2, 3, 4, 6, 8] - - cls.TEST_INPUT_POINTS = np.array([[0, 0, 0], [-0.5, 0.5, 0.5]]) - cls.TEST_OUTPUT_TRUTH = get_output_3D(*cls.TEST_INPUT_POINTS.T) - - def setUp(self): - my_kwargs = {"n_neighbors": [2, 3, 5]} - self.table_grid = self.create_TableData(self.TEST_DATA_GRID, **my_kwargs) - self.table_rand = self.create_TableData(self.TEST_DATA_RAND, **my_kwargs) - self.regr_grid = Regressor(self.table_grid) - self.regr_rand = Regressor(self.table_rand) - - def create_TableData(self, data_frame, **kwargs): - files = None - input_cols = ["input_1", "input_2", "input_3"] - output_cols = ["class", "output_1"] - class_col_name = "class" - table_obj = TableData( - files, - input_cols, - output_cols, - class_col_name, - my_DataFrame=data_frame, - verbose=False, - **kwargs - ) - return table_obj - - def train_grid_regressors(self, names, *args, **kwargs): - for i, regr_name in enumerate(names): - self.regr_grid.train(regr_name, *args, **kwargs) - - def test_train_regressors_1(self): - # di can be [16,19,24,32,41,48,64] to avoid multi class error for gp - classes_to_train = [ [1,2,3,4,6,8], [1,6,8], [1,6,8] ] - col_keys = ["output_1"] - for i, regr in enumerate(["rbf", "linear", "gp"]): - with self.subTest("Train grid:", regressor=regr): - self.regr_grid.train(regr, classes_to_train[i], col_keys, - verbose=False) - - def test_train_regressors_2(self): - # skipping gp here - classes_to_train = [ [1,2,3,4,6,8], [1,6,8], [1,6,8] ] - col_keys = ["output_1"] - for i, regr in enumerate(["rbf", "linear"]): - with self.subTest("Train random:", regressor=regr): - self.regr_rand.train(regr, classes_to_train[i], - col_keys, verbose=False) - - def test_predictions(self): - """Checking for consistency only, not true values.""" - self.train_grid_regressors(["rbf", "linear"], [6], ["output_1"]) - - regr_out = self.regr_grid.get_predictions( - ["rbf", "linear"], [6], ["output_1"], self.TEST_INPUT_POINTS - ) - # RBF check - for i, corr_ans in enumerate([-0.78295781, -0.7401497]): - with self.subTest("RBF regr", correct_ans=corr_ans): - pred = regr_out["RBF"][6]["output_1"][i] - self.assertAlmostEqual(pred, corr_ans, places=5) - - # LinearNDInterpolator check - for i, corr_ans in enumerate([0.13898644, float("Nan")]): - with self.subTest("LinearNDInterpolator regr", correct_ans=corr_ans): - pred = regr_out["LinearNDInterpolator"][6]["output_1"][i] - if i == 1: - self.assertTrue(math.isnan(pred), msg="Prediction should be Nan. {}".format(pred)) - else: - self.assertAlmostEqual(pred, corr_ans, places=5) - - @unittest.skipIf(SKIP_GP_TESTS, "GP train / predict - longer runtime.") - def test_predictions_gp(self): - self.train_grid_regressors(["gp"], [6], ["output_1"], di=None) - - regr_out = self.regr_grid.get_predictions( - ["gp"], [6], ["output_1"], self.TEST_INPUT_POINTS - ) - for i, corr_ans in enumerate([0,0]): - with self.subTest("GaussianProcessRegressor", correct_ans=corr_ans): - pred = regr_out["GaussianProcessRegressor"][6]["output_1"][i] - self.assertAlmostEqual(pred, corr_ans, places=5) - - - def test_pred_train_err(self): - # Trying to predict without training - names = ["grid", "random"] - for i, regressor in enumerate([self.regr_grid, self.regr_rand]): - with self.subTest(classifier_name=names[i]): - with self.assertRaisesRegex( - Exception, "No trained interpolators exist" - ): - regressor.get_predictions( ["linear"], [6], ["output_1"], [[0, 0, 0]]) - - def test_cross_val(self): - corr_ans = [-16.528141760898365, -61.41327730988214, -10.621903342287425] - for index, cls in enumerate([1,6,8]): - with self.subTest("Cross Validation Regression", class_key=cls): - perc_diffs, actual_diffs = self.regr_grid.cross_validate("rbf", cls, "output_1", 0.5 ) - self.assertAlmostEqual( np.mean(perc_diffs), corr_ans[index], places=5) - - plt.hist(perc_diffs, bins=40, density=True, range=(-300,300), - histtype="step", label="class "+str(cls)) - plt.xlabel("Regression Percent Difference") - plt.title("Test Regression CV") - plt.legend() - if SHOW_PLOTS: - plt.show() - plt.close() - - - @unittest.skipIf(SKIP_GP_TESTS, "GP cross_val - longer runtime.") - def test_cross_val_gp(self): - corr_ans = [-36.08266156603506, -100.0, -70.79557229587994] - for index, cls in enumerate([1,6,8]): - with self.subTest("Cross Validation Regression GP", class_key=cls, ans=corr_ans[index]): - perc_diffs, actual_diffs = self.regr_grid.cross_validate("gp", cls, "output_1", 0.5 ) - self.assertAlmostEqual( np.mean(perc_diffs), corr_ans[index], places=5) - - plt.hist(perc_diffs, bins=40, density=True, range=(-300,300), - histtype="step", label="class "+str(cls)) - plt.xlabel("Regression Percent Difference") - plt.title("Test GP Regression CV") - plt.legend() - if SHOW_PLOTS: - plt.show() - plt.close() - - @unittest.skipIf(SKIP_TEST_PLOTS, "All regression data plot.") - def test_max_cls_plot(self): - class_key = 1 - fig = self.regr_grid.plot_regr_data(class_key) - if SHOW_PLOTS: - plt.show() - else: - print("To show plots set SHOW_PLOTS to True.") - plt.close() - - -if __name__ == "__main__": - unittest.main() diff --git a/posydon/tests/active_learning/psy_cris/test_Sampler.py b/posydon/tests/active_learning/psy_cris/test_Sampler.py deleted file mode 100644 index 8ac35611dc..0000000000 --- a/posydon/tests/active_learning/psy_cris/test_Sampler.py +++ /dev/null @@ -1,166 +0,0 @@ -"""Unit test for posydon.active_learning.psy_cris classes -""" -import math -import unittest - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd - -from posydon.active_learning.psy_cris.classify import Classifier -from posydon.active_learning.psy_cris.data import TableData -from posydon.active_learning.psy_cris.regress import Regressor -from posydon.active_learning.psy_cris.sample import Sampler -from posydon.active_learning.psy_cris.synthetic_data.synth_data_3D import get_output_3D -from posydon.active_learning.psy_cris.utils import ( - get_random_grid_df, - get_regular_grid_df, -) - -SKIP_TEST_PLOTS = False -SHOW_PLOTS = False - - -class TestSampler(unittest.TestCase): - """Test Sampler class on the 3d synthetic data set.""" - - @classmethod - def setUpClass(cls): - np.random.seed(12345) - cls.TEST_DATA_GRID = get_regular_grid_df(N=10 ** 3, dim=3) - cls.TEST_DATA_RAND = get_random_grid_df(N=10 ** 3, dim=3) - cls.UNIQUE_CLASSES = [1, 2, 3, 4, 6, 8] - - cls.TEST_INPUT_POINTS = np.array([[0, 0, 0], [-0.5, 0.5, 0.5]]) - cls.TEST_OUTPUT_TRUTH = get_output_3D(*cls.TEST_INPUT_POINTS.T) - - def setUp(self): - my_kwargs = {"n_neighbors": [2, 3, 5]} - self.table_grid = self.create_TableData(self.TEST_DATA_GRID, **my_kwargs) - self.regr_grid = Regressor(self.table_grid) - self.cls_grid = Classifier(self.table_grid) - - def create_TableData(self, data_frame, **kwargs): - files = None - input_cols = ["input_1", "input_2", "input_3"] - output_cols = ["class", "output_1"] - class_col_name = "class" - table_obj = TableData( - files, - input_cols, - output_cols, - class_col_name, - my_DataFrame=data_frame, - verbose=False, - **kwargs - ) - return table_obj - - def train_everything_grid(self, cls_names, regr_names): - if cls_names is not None: - self.cls_grid.train_everything(cls_names) - if regr_names is not None: - self.regr_grid.train_everything(regr_names) - - def test_init_0(self): - test_cases = [ - (None, None), - (self.cls_grid, self.regr_grid), - (None, self.regr_grid), - ] - for i, tup_input in enumerate([(None, None), ()]): - with self.subTest("Sampler init", iter=i): - samp = Sampler(*tup_input) - - def test_mcmc(self): - self.train_everything_grid(["rbf"], None) - samp = Sampler(classifier=self.cls_grid, regressor=None) - - steps, acc, rej = samp.run_MCMC( - 15, 0.25, [0, 0, 0], samp.TD_classification, "rbf", T=1, **{"TD_BETA": 2} - ) - self.assertTrue(len(steps) == (acc + 1), msg="steps taken should match acc.") - - def test_ptmcmc(self): - self.train_everything_grid(["rbf"], ["rbf"]) - samp = Sampler(classifier=self.cls_grid, regressor=self.regr_grid) - - chain_step_hist, T_list = samp.run_PTMCMC( - 5, - 15, - samp.TD_classification_regression, - ("rbf", "rbf"), - init_pos=[0, 0, 0], - alpha=0.25, - verbose=False, - trace_plots=False, - TD_BETA=1, - ) - # try with default values - chain_step_hist, T_list = samp.run_PTMCMC( - 10, 15, samp.TD_classification_regression, ("rbf", "rbf"), - verbose=False, trace_plots=False) - - - def test_simple_density_logic(self): - self.cls_grid.train("rbf") - samp = Sampler(classifier=self.cls_grid, regressor=None) - steps, acc, rej = samp.run_MCMC( - 200, 0.25, [0, 0, 0], samp.TD_classification, "rbf", T=1 - ) - acc_pts, rej_pts = samp.do_simple_density_logic(steps, 10, 0.05) - return samp, steps - - def test_get_proposed_points(self): - N = 10 - samp, step_hist = self.test_simple_density_logic() - prop_points, kappa = samp.get_proposed_points(step_hist, N, 0.046) - self.assertTrue(len(prop_points) == N) - - @unittest.skipIf(SKIP_TEST_PLOTS, "Plotting C, C+R target distributions") - def test_TD_plots(self): - self.train_everything_grid(["rbf"], ["rbf"]) - samp = Sampler(classifier=self.cls_grid, regressor=self.regr_grid) - - N = 70 if SHOW_PLOTS else 5 - zed = 0 - x, y = np.meshgrid(np.linspace(-1, 1, N), np.linspace(-1, 1, N)) - z = np.ones(x.shape) * zed - data_points = np.concatenate( - (x.flatten()[:, None], y.flatten()[:, None], z.flatten()[:, None]), axis=1 - ) - - max_probs, pos, cls_keys = samp.get_TD_classification_data("rbf", data_points) - - kwargs = {"TD_BETA": 2, "TD_TAU": 0.5} - cls_regr_td_vals = [ - float(samp.TD_classification_regression(["rbf", "rbf"], dat, **kwargs)) - for dat in data_points - ] - - fig, subs = plt.subplots(1, 2, figsize=(13, 5)) - subs[0].set_title("TD_classification at z = {}".format(zed)) - cls_plot = subs[0].pcolormesh( - x, y, (1 - max_probs).reshape(N, N), shading="auto" - ) - - subs[1].set_title("TD_classification_regression at z = {}".format(zed)) - cls_regr_plot = subs[1].pcolormesh( - x, y, np.array(cls_regr_td_vals).reshape(N, N), shading="auto" - ) - - fig.colorbar(cls_plot, ax=subs[0]) - fig.colorbar(cls_regr_plot, ax=subs[1]) - - for i in range(2): - subs[i].set_xlabel("input_1") - subs[1].set_ylabel("input_2") - subs[i].axis("equal") - - if SHOW_PLOTS: - plt.show() - plt.close() - - -if __name__ == "__main__": - unittest.main() diff --git a/posydon/tests/active_learning/psy_cris/test_TableData.py b/posydon/tests/active_learning/psy_cris/test_TableData.py deleted file mode 100644 index ae1795e4e8..0000000000 --- a/posydon/tests/active_learning/psy_cris/test_TableData.py +++ /dev/null @@ -1,247 +0,0 @@ -"""Unit test for posydon.active_learning.psy_cris classes -""" -import unittest - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd - -from posydon.active_learning.psy_cris.data import TableData -from posydon.active_learning.psy_cris.synthetic_data.synth_data_3D import get_output_3D -from posydon.active_learning.psy_cris.utils import ( - get_random_grid_df, - get_regular_grid_df, -) - -SKIP_TEST_PLOTS = False -SHOW_PLOTS = False - - -class TestTableData(unittest.TestCase): - """Test TableData class on the 3d synthetic data set.""" - - @classmethod - def setUpClass(cls): - np.random.seed(12345) - cls.TEST_DATA_GRID = get_regular_grid_df(N=10 ** 3, dim=3) - cls.TEST_DATA_RAND = get_random_grid_df(N=10 ** 3, dim=3) - cls.UNIQUE_CLASSES = [1, 2, 3, 4, 6, 8] - - def create_TableData(self, data_frame, **kwargs): - files = None - input_cols = ["input_1", "input_2", "input_3"] - output_cols = ["class", "output_1"] - class_col_name = "class" - table_obj = TableData( - files, - input_cols, - output_cols, - class_col_name, - my_DataFrame=data_frame, - verbose=False, - **kwargs - ) - return table_obj - - def test_init_0(self): - td = self.create_TableData(self.TEST_DATA_GRID) - self.assertTrue(isinstance(td, TableData)) - - def test_init_1_grid(self): - my_kwargs = {"n_neighbors": [2, 3]} - table_grid = self.create_TableData(self.TEST_DATA_GRID, **my_kwargs) - # N classes - self.assertTrue( - table_grid.num_classes == len(self.UNIQUE_CLASSES), - msg="Should find 6 classes. Found {}".format(table_grid.num_classes), - ) - # Unique classes + APC cols - regr_data = table_grid.get_regr_data(what_data="output") - for cls_key in regr_data.keys(): - with self.subTest("Checking data by class.", cls_key=cls_key): - self.assertIn(cls_key, self.UNIQUE_CLASSES) - self.assertIn("APC2_output_1", regr_data[cls_key].columns) - if cls_key != 2: - self.assertIn("APC3_output_1", regr_data[cls_key].columns) - - def test_init_2_rand(self): - my_kwargs = {"n_neighbors": [2, 3]} - table_rand = self.create_TableData(self.TEST_DATA_RAND, **my_kwargs) - - # N classes - self.assertTrue( - table_rand.num_classes == len(self.UNIQUE_CLASSES), - msg="Should find 6 classes. Found {}".format(table_rand.num_classes), - ) - # Unique classes + APC cols - regr_data = table_rand.get_regr_data(what_data="output") - for cls_key in regr_data.keys(): - with self.subTest("Checking data by class.", cls_key=cls_key): - self.assertIn(cls_key, self.UNIQUE_CLASSES) - self.assertIn("APC2_output_1", regr_data[cls_key].columns) - self.assertIn("APC3_output_1", regr_data[cls_key].columns) - - def test_init_3_clean_data(self): - my_kwargs = {"n_neighbors": [2, 3], "omit_vals": [-1]} - table_grid = self.create_TableData(self.TEST_DATA_GRID, **my_kwargs) - - self.assertTrue( - len(table_grid.get_data()) == 729, - msg="Should remove 271 rows from grid data set with value -1.", - ) - - def test_init_4_clean_data(self): - my_kwargs = {"n_neighbors": [2, 3], "omit_vals": [-1]} - table_rand = self.create_TableData(self.TEST_DATA_RAND, **my_kwargs) - - self.assertTrue( - len(table_rand.get_data()) == 1000, - msg="Should remove 0 rows from random data set with value -1.", - ) - - def test_classification_data(self): - my_kwargs = {"n_neighbors": [2, 3], "omit_vals": [-1]} - table_grid = self.create_TableData(self.TEST_DATA_GRID, **my_kwargs) - - binary_classification_data = table_grid.get_binary_mapping_per_class() - self.assertTrue( - binary_classification_data.shape == (len(self.UNIQUE_CLASSES), 729) - ) - self.assertTrue( - all( - [ - sum((row == 1) + (row == 0)) == 729 - for row in binary_classification_data - ] - ) - ) - - def test_regression_data_1(self): - my_kwargs = { - "n_neighbors": [2, 3], - } - table_grid = self.create_TableData(self.TEST_DATA_GRID, **my_kwargs) - output_dat = table_grid.get_regr_data(what_data="output") - for cls in [5, 7]: - with self.subTest(cls=cls): - # Raise KeyError for classes that shouldn't exist - with self.assertRaisesRegex(KeyError, str(cls)): - output_dat[cls] - - def test_regression_data_2(self): - my_kwargs = {"n_neighbors": [2, 3, 5, 10, 50]} - table_grid = self.create_TableData(self.TEST_DATA_GRID, **my_kwargs) - table_rand = self.create_TableData(self.TEST_DATA_RAND, **my_kwargs) - - for key, val in table_grid._regr_dfs_per_class_.items(): - col_names = list(val.columns) - with self.subTest(key=key, data="grid"): - self.assertTrue(any(["APC2" in item for item in col_names])) - if key in [1, 3, 4, 8]: - self.assertTrue(any(["APC50" in item for item in col_names])) - - for key, val in table_rand._regr_dfs_per_class_.items(): - col_names = list(val.columns) - with self.subTest(key=key, data="random"): - self.assertTrue(any(["APC2" in item for item in col_names])) - self.assertTrue(any(["APC3" in item for item in col_names])) - if key in [1, 3, 6, 8]: - self.assertTrue(any(["APC50" in item for item in col_names])) - - @unittest.skipIf( SKIP_TEST_PLOTS, "Test by plotting nearest neighbors skipped") - def test_nearest_neighbhors(self): - my_kwargs = { - "n_neighbors": [2, 3], - } - table_grid = self.create_TableData(self.TEST_DATA_GRID, **my_kwargs) - n_neigh = 3 - - dat = np.random.uniform(low=(-1, -1), high=(1, 1), size=(15, 2)) - output = table_grid.find_n_neighbors(dat, [n_neigh]) - - plt.figure(figsize=(4, 4), dpi=100) - plt.title("NearestNeighbors test") - plt.plot( - dat.T[0][0], dat.T[1][0], "+", color="r", markersize=10, label="reference" - ) - plt.scatter(dat.T[0], dat.T[1], label="data") - for i in range(n_neigh): - plt.scatter( - dat.T[0][output[n_neigh][0][i]], - dat.T[1][output[n_neigh][0][i]], - marker="+", - color="lime", - s=29, - label="nearest", - ) - plt.axis("equal") - if SHOW_PLOTS: - plt.show() - else: - print("To show plots set SHOW_PLOTS to True.") - plt.close() - - @unittest.skipIf( SKIP_TEST_PLOTS, "Test for general plotting skipped") - def test_plotting_1(self): - my_kwargs = { - "n_neighbors": [2, 3], - } - table_grid = self.create_TableData(self.TEST_DATA_GRID, **my_kwargs) - table_rand = self.create_TableData(self.TEST_DATA_RAND, **my_kwargs) - - zed_val = 1 - N = 40 - X, Y = np.meshgrid(np.linspace(-1, 1, N), np.linspace(-1, 1, N)) - Z = np.ones(X.shape) * zed_val - f_out = get_output_3D(X, Y, Z) - print("ZED VAL: {}".format(zed_val)) - - fig, subs = plt.subplots(1, 3, figsize=(14, 4), dpi=100) - subs[0].set_title("TableData - even grid") - fig, subs[0], handles = table_grid.make_class_data_plot( - fig, - subs[0], - ["input_1", "input_2"], - my_slice_vals={0: (0.9, 1.1)}, - return_legend_handles=True, - ) - subs[0].legend( - handles, table_grid._unique_class_keys_, bbox_to_anchor=(-0.25, 0.5) - ) - - subs[1].set_title("TableData - random points") - fig, subs[1], handles = table_rand.make_class_data_plot( - fig, - subs[1], - ["input_1", "input_2"], - my_slice_vals={0: (0.9, 1.1)}, - return_legend_handles=True, - ) - - subs[2].set_title("Analytic Classification") - subs[2].pcolormesh(X, Y, f_out["class"].values.reshape(N, N), shading="auto") - for i in range(3): - subs[i].axis("equal") - if SHOW_PLOTS: - plt.show() - else: - print("To show plots set SHOW_PLOTS to True.") - plt.close() - - @unittest.skipIf( SKIP_TEST_PLOTS, "Test for plotting 3d skipped") - def test_plotting_2(self): - my_kwargs = { - "n_neighbors": [2, 3], - } - table_grid = self.create_TableData(self.TEST_DATA_GRID, **my_kwargs) - table_grid.plot_3D_class_data() - plt.title("plot_3D_class_data") - if SHOW_PLOTS: - plt.show() - else: - print("To show plots set SHOW_PLOTS to True.") - plt.close() - - -if __name__ == "__main__": - unittest.main() diff --git a/posydon/tests/active_learning/psy_cris/test_utils.py b/posydon/tests/active_learning/psy_cris/test_utils.py deleted file mode 100644 index bd666c1ae2..0000000000 --- a/posydon/tests/active_learning/psy_cris/test_utils.py +++ /dev/null @@ -1,76 +0,0 @@ -"""Unit test for posydon.active_learning.psy_cris classes -""" -import os -import unittest - -import numpy as np -import pandas as pd - -from posydon.active_learning.psy_cris.utils import ( - check_dist, - get_new_query_points, - get_random_grid_df, - get_regular_grid_df, - parse_inifile, -) -from posydon.config import PATH_TO_POSYDON - - -class TestUtils(unittest.TestCase): - """Test methods in utils.""" - - @classmethod - def setUpClass(cls): - np.random.seed(12345) - psy_cris_dir = os.path.join( - PATH_TO_POSYDON, "posydon/active_learning/psy_cris") - cls.INI_FILE_PATH = os.path.join(psy_cris_dir, - "run_params/psycris_default.ini") - - def test_parse_inifile(self): - self.assertTrue(os.path.isfile(self.INI_FILE_PATH), msg="Can't find file.") - my_kwargs = parse_inifile(self.INI_FILE_PATH) - self.assertTrue(isinstance(my_kwargs, dict)) - return my_kwargs - - def test_get_new_query_points(self): - my_kwargs = self.test_parse_inifile() - holder = my_kwargs["TableData_kwargs"] - holder["my_DataFrame"] = get_regular_grid_df(N=10 ** 3, dim=3) - my_kwargs["TableData_kwargs"] = holder - - holder_1 = my_kwargs["Sampler_kwargs"] - holder_1["N_tot"] = 50 - holder_1["T_max"] = 5 - holder_1["verbose"] = False - my_kwargs["Sampler_kwargs"] = holder_1 - - query_pts, preds = get_new_query_points(3, **my_kwargs) - self.assertTrue(len(query_pts) == 3) - - def test_check_dist(self): - original_pts = np.random.uniform( - low=(-1, -1, -1), high=(1, 1, 1), size=(500, 3) - ) - proposed_pts = get_regular_grid_df(N=10 ** 3, dim=3).values[:, 0:3] - result = check_dist(original_pts, proposed_pts, threshold=1e-2) - self.assertTrue( - sum(result) == len(proposed_pts), - msg="All points should not be within 1e-2 of eachother.", - ) - - def test_get_regular_grid_df(self): - for config in [dict(N=1000, dim=3), dict(N=50, dim=2), dict(jitter=True)]: - with self.subTest("regular_grid_df", config=config): - df = get_regular_grid_df(**config) - self.assertTrue(isinstance(df, pd.DataFrame)) - - def test_get_random_grid_df(self): - for config in [dict(N=1000, dim=3), dict(N=50, dim=2)]: - with self.subTest("random_grid_df", config=config): - df = get_random_grid_df(**config) - self.assertTrue(isinstance(df, pd.DataFrame)) - - -if __name__ == "__main__": - unittest.main() diff --git a/posydon/tests/binary_evol/CE/test_CEE.py b/posydon/tests/binary_evol/CE/test_CEE.py deleted file mode 100644 index ddeea07647..0000000000 --- a/posydon/tests/binary_evol/CE/test_CEE.py +++ /dev/null @@ -1,655 +0,0 @@ -import os -import unittest - -import numpy as np - -from posydon.binary_evol.binarystar import BinaryStar -from posydon.binary_evol.CE.step_CEE import StepCEE -from posydon.binary_evol.singlestar import SingleStar -from posydon.config import PATH_TO_POSYDON -from posydon.utils import common_functions as cf - -# spaces are read '\\ ' instead of ' ' -PATH_TO_DATA = os.path.join( - PATH_TO_POSYDON, "posydon/tests/data/POSYDON-UNIT-TESTS/binary_evol/CE/") - - -class TestCommonEnvelope(unittest.TestCase): - def test_common_envelope_1(self): - kwargs = {'prescription': 'alpha-lambda', - "common_envelope_option_for_lambda" : 'default_lambda'} - - CEE = StepCEE(verbose=False, **kwargs) - - # simple binary system which will experience CEE with default_lambda - # option on. - # no profiles needed for this - PROPERTIES_STAR1 = { - 'mass': 10.0, - 'log_R': np.log10(1000.0), - 'he_core_mass': 3.0, - 'he_core_radius': 0.5, - 'state': 'H-rich_Shell_H_burning', - 'metallicity' : 0.0142, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.0, - "center_h1" : 1.0, - "center_c12" : 0.01, - } - PROPERTIES_STAR2 = { - 'mass': 2.0, - 'log_R': np.log10(2.0), - 'he_core_mass': 0.0, - 'he_core_radius': 0.0, - 'state': 'H-rich_Core_H_burning', - 'metallicity' : 0.0142, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.49, - "center_h1" : 0.49, - "center_c12" : 0.01, - } - giantstar = SingleStar(**PROPERTIES_STAR1) - compstar = SingleStar(**PROPERTIES_STAR2) - - orbital_separation_for_RLOF = 10**giantstar.log_R / cf.roche_lobe_radius( - giantstar.mass, compstar.mass, a_orb=1) - orbital_period_for_RLOF = cf.orbital_period_from_separation( - orbital_separation_for_RLOF, giantstar.mass, compstar.mass) - PROPERTIES_BINARY = { - "binary_state": "RLO1", - "event": "oCE1", - "orbital_period": orbital_period_for_RLOF - } - binary = BinaryStar(star_1=giantstar, - star_2=compstar, - **PROPERTIES_BINARY) - - CEE(binary) - #self.assertTrue(binary.event == 'redirect', "CEE test 1 failed") - self.assertTrue( - abs(binary.orbital_period - 5.056621408721529) < - 1.0, "CEE test 1 failed") - self.assertTrue("stripped_He" in binary.star_1.state, "CEE test 1 failed") - - def test_common_envelope_2(self): - kwargs = {'prescription': 'alpha-lambda', - "common_envelope_option_for_lambda" : 'lambda_from_profile_gravitational'} - - CEE = StepCEE(verbose=False, **kwargs) - - # testing with loading a profile of the donor at the moment of CEE to - # calculate the lamda CEE - profile_donor_name = os.path.join(PATH_TO_DATA, - 'simple_giant_profile_for_CEE.npy') - profile_donor = np.load(profile_donor_name) - PROPERTIES_STAR1_withprofile = { - 'mass': 22.77, - 'log_R': np.log10(1319.0), - 'he_core_mass': 11.0, - 'he_core_radius': 0.6, - 'state': 'H-rich_Shell_H_burning', - 'metallicity' : 0.0142, - 'profile': profile_donor, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.0, - "center_h1" : 1.0, - "center_c12" : 0.01, - } - giantstar_withprofile = SingleStar(**PROPERTIES_STAR1_withprofile) - PROPERTIES_STAR2 = { - 'mass': 2.0, - 'log_R': np.log10(2.0), - 'he_core_mass': 0.0, - 'he_core_radius': 0.0, - 'state': 'H-rich_Core_H_burning', - 'metallicity' : 0.0142, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.5, - "center_h1" : 0.5, - "center_c12" : 0.01, - } - compstar = SingleStar(**PROPERTIES_STAR2) - - orbital_separation_for_RLOF = 10**giantstar_withprofile.log_R / cf.roche_lobe_radius( - giantstar_withprofile.mass, compstar.mass, a_orb=1) - orbital_period_for_RLOF = cf.orbital_period_from_separation( - orbital_separation_for_RLOF, giantstar_withprofile.mass, - compstar.mass) - - PROPERTIES_BINARY_withprofile = { - "binary_state": "RLO1", - "event": "oCE1", - "orbital_period": orbital_period_for_RLOF - } - binary_withprofile = BinaryStar(giantstar_withprofile, compstar, - **PROPERTIES_BINARY_withprofile) - # options: 'default_lambda', 'lambda_from_profile_gravitational', - # 'lambda_from_profile_gravitational_plus_internal', - # 'lambda_from_profile_gravitational_plus_internal_minus_recombination' - #binary_withprofile.properties.common_envelope_option_for_lambda = "lambda_from_profile_gravitational" - CEE(binary_withprofile) - #print(binary_withprofile.event) - self.assertTrue(binary_withprofile.state == "merged", - "CEE test 2 failed") - - def test_common_envelope_3(self): - kwargs = {'prescription': 'alpha-lambda', - "common_envelope_option_for_lambda" : 'lambda_from_profile_gravitational_plus_internal_minus_recombination'} - - CEE = StepCEE(verbose=False, **kwargs) - - # testing with loading a profile of the donor at the moment of CEE to - # calculate the lamda CEE, taking into account also the internal energy - # - recombination energy - profile_donor_name = os.path.join( - PATH_TO_DATA, - 'giant_profile_for_CEE_with_recombinationenergy_calculation.npy') - #profile_donor = np.load(profile_donor_name, mmap_mode = "r") - profile_donor = np.load(profile_donor_name) - PROPERTIES_STAR1_withprofile = { - 'mass': 22.77, - 'log_R': np.log10(1319.0), - 'he_core_mass': 11.0, - 'he_core_radius': 0.6, - 'state': 'H-rich_Shell_H_burning', - 'metallicity' : 0.0142, - 'profile': profile_donor, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.0, - "center_h1" : 1.0, - "center_c12" : 0.01, - } - giantstar_withprofile = SingleStar(**PROPERTIES_STAR1_withprofile) - PROPERTIES_STAR2 = { - 'mass': 2.0, - 'log_R': np.log10(2.0), - 'he_core_mass': 0.0, - 'he_core_radius': 0.0, - 'state': 'H-rich_Core_H_burning', - 'metallicity' : 0.0142, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.49, - "center_h1" : 0.49, - "center_c12" : 0.01, - } - compstar = SingleStar(**PROPERTIES_STAR2) - - orbital_separation_for_RLOF = 10**giantstar_withprofile.log_R / cf.roche_lobe_radius( - giantstar_withprofile.mass, compstar.mass, a_orb=1) - orbital_period_for_RLOF = cf.orbital_period_from_separation( - orbital_separation_for_RLOF, giantstar_withprofile.mass, - compstar.mass) - - PROPERTIES_BINARY_withprofile = { - "binary_state": "RLO1", - "event": "oCE1", - "orbital_period": orbital_period_for_RLOF - } - binary_withprofile = BinaryStar(giantstar_withprofile, compstar, - **PROPERTIES_BINARY_withprofile) - # options: 'default_lambda', 'lambda_from_profile_gravitational', - # 'lambda_from_profile_gravitational_plus_internal', - # 'lambda_from_profile_gravitational_plus_internal_minus_recombination' - #binary_withprofile.properties.common_envelope_option_for_lambda = "lambda_from_profile_gravitational_plus_internal_minus_recombination" - #print(binary_withprofile.properties.common_envelope_option_for_lambda) - CEE(binary_withprofile) - #print(binary_withprofile.event) - self.assertTrue(binary_withprofile.state == "merged", - "CEE test 3 failed") - - def test_common_envelope_4(self): - kwargs = {'prescription': 'alpha-lambda', - "common_envelope_option_for_lambda" : 'lambda_from_profile_gravitational'} - - CEE = StepCEE(verbose=False, **kwargs) - - profile_donor_name = os.path.join(PATH_TO_DATA, - 'simple_giant_profile_for_CEE.npy') - profile_donor = np.load(profile_donor_name) - PROPERTIES_STAR1_withprofile = { - 'mass': 22.77, - 'log_R': np.log10(1319.0), - 'he_core_mass': 11.0, - 'he_core_radius': 0.6, - 'state': 'H-rich_Shell_H_burning', - 'metallicity' : 0.0142, - 'profile': profile_donor, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.0, - "center_h1" : 1.0, - "center_c12" : 0.01, - } - giantstar_withprofile = SingleStar(**PROPERTIES_STAR1_withprofile) - PROPERTIES_STAR2 = { - 'mass': 10.0, - 'log_R': np.log10(7.0), - 'he_core_mass': 0.0, - 'he_core_radius': 0.0, - 'state': 'H-rich_Core_H_burning', - 'metallicity' : 0.0142, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.49, - "center_h1" : 0.49, - "center_c12" : 0.01, - } - compstar = SingleStar(**PROPERTIES_STAR2) - - orbital_separation_for_RLOF = 10**giantstar_withprofile.log_R / cf.roche_lobe_radius( - giantstar_withprofile.mass, compstar.mass, a_orb=1) - orbital_period_for_RLOF = cf.orbital_period_from_separation( - orbital_separation_for_RLOF, giantstar_withprofile.mass, - compstar.mass) - - PROPERTIES_BINARY_withprofile = { - "binary_state": "RLO1", - "event": "oCE1", - "orbital_period": orbital_period_for_RLOF - } - binary_withprofile = BinaryStar(giantstar_withprofile, compstar, - **PROPERTIES_BINARY_withprofile) - # options: 'default_lambda', 'lambda_from_profile_gravitational', - # 'lambda_from_profile_gravitational_plus_internal', - # 'lambda_from_profile_gravitational_plus_internal_minus_recombination' - #binary_withprofile.properties.common_envelope_option_for_lambda = "lambda_from_profile_gravitational" - #print(binary_withprofile.properties.common_envelope_option_for_lambda) - CEE(binary_withprofile) - #print(binary_withprofile.event) - self.assertTrue(binary_withprofile.state == "merged", - "CEE test 4 failed") - - def test_common_envelope_5(self): - kwargs = {'prescription': 'alpha-lambda', - "common_envelope_option_for_lambda" : 'lambda_from_profile_gravitational_plus_internal_minus_recombination'} - - CEE = StepCEE(verbose=False, **kwargs) - # testing with loading a profile of the donor at the moment of CEE to - # calculate the lamda CEE, taking into account also the internal - # energy - recombination energy - profile_donor_name = os.path.join( - PATH_TO_DATA, - 'giant_profile_for_CEE_with_recombinationenergy_calculation.npy') - profile_donor = np.load(profile_donor_name) - PROPERTIES_STAR1_withprofile = { - 'mass': 22.77, - 'log_R': np.log10(1319.0), - 'he_core_mass': 11.0, - 'he_core_radius': 0.6, - 'state': 'H-rich_Shell_H_burning', - 'metallicity' : 0.0142, - 'profile': profile_donor, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.0, - "center_h1" : 1.0, - "center_c12" : 0.01, - } - giantstar_withprofile = SingleStar(**PROPERTIES_STAR1_withprofile) - PROPERTIES_STAR2 = { - 'mass': 10.0, - 'log_R': np.log10(7.0), - 'he_core_mass': 0.0, - 'he_core_radius': 0.0, - 'state': 'H-rich_Core_H_burning', - 'metallicity' : 0.0142, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.49, - "center_h1" : 0.49, - "center_c12" : 0.01, - } - compstar = SingleStar(**PROPERTIES_STAR2) - - orbital_separation_for_RLOF = 10**giantstar_withprofile.log_R / cf.roche_lobe_radius( - giantstar_withprofile.mass, compstar.mass, a_orb=1) - orbital_period_for_RLOF = cf.orbital_period_from_separation( - orbital_separation_for_RLOF, giantstar_withprofile.mass, - compstar.mass) - - PROPERTIES_BINARY_withprofile = { - "binary_state": "RLO1", - "event": "oCE1", - "orbital_period": orbital_period_for_RLOF - } - binary_withprofile = BinaryStar(giantstar_withprofile, compstar, - **PROPERTIES_BINARY_withprofile) - # options: 'default_lambda', 'lambda_from_profile_gravitational', - # 'lambda_from_profile_gravitational_plus_internal', - # 'lambda_from_profile_gravitational_plus_internal_minus_recombination' - #binary_withprofile.properties.common_envelope_option_for_lambda = "lambda_from_profile_gravitational_plus_internal_minus_recombination" - #print(binary_withprofile.properties.common_envelope_option_for_lambda) - CEE(binary_withprofile) - #print(binary_withprofile.event) - self.assertTrue(binary_withprofile.state == "merged", - "CEE test 5 failed") - - def test_common_envelope_6(self): - kwargs = {'prescription': 'alpha-lambda', - "common_envelope_option_for_lambda" : 'lambda_from_profile_gravitational'} - - CEE = StepCEE(verbose=False, **kwargs) - - profile_donor_name = os.path.join(PATH_TO_DATA, - 'simple_giant_profile_for_CEE.npy') - #profile_donor = np.genfromtxt(profile_donor_name, skip_header=5, names=True, dtype=None) - profile_donor = np.load(profile_donor_name) - PROPERTIES_STAR1_withprofile = { - 'mass': 22.77, - 'log_R': np.log10(1319.0), - 'he_core_mass': 11.0, - 'he_core_radius': 0.6, - 'state': 'H-rich_Shell_H_burning', - 'metallicity' : 0.0142, - 'profile': profile_donor, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.0, - "center_h1" : 1.0, - "center_c12" : 0.01, - } - giantstar_withprofile = SingleStar(**PROPERTIES_STAR1_withprofile) - PROPERTIES_STAR2 = { - 'mass': 10., - 'log_R': np.log10(0.0001), - 'he_core_mass': 0.0, - 'he_core_radius': 0.0, - 'state': 'BH' - } - compstar = SingleStar(**PROPERTIES_STAR2) - - orbital_separation_for_RLOF = 10**giantstar_withprofile.log_R / cf.roche_lobe_radius( - giantstar_withprofile.mass, compstar.mass, a_orb=1) - orbital_period_for_RLOF = cf.orbital_period_from_separation( - orbital_separation_for_RLOF, giantstar_withprofile.mass, - compstar.mass) - - PROPERTIES_BINARY_withprofile = { - "binary_state": "RLO1", - "event": "oCE1", - "orbital_period": orbital_period_for_RLOF - } - binary_withprofile = BinaryStar(giantstar_withprofile, compstar, - **PROPERTIES_BINARY_withprofile) - #binary_withprofile.properties.common_envelope_option_for_lambda = "lambda_from_profile_gravitational" # options: 'default_lambda', 'lambda_from_profile_gravitational', 'lambda_from_profile_gravitational_plus_internal', 'lambda_from_profile_gravitational_plus_internal_minus_recombination' - #print(binary_withprofile.properties.common_envelope_option_for_lambda) - CEE(binary_withprofile) - #print("new state of the star that triggered CEE = ",giantstar_withprofile.state) - #print("new mass of the star that triggered CEE = ",giantstar_withprofile.mass) - print(binary_withprofile.event) - #self.assertTrue(binary_withprofile.event == 'redirect', - # "CEE test 6 failed") - self.assertTrue((10**giantstar_withprofile.log_R - cf.roche_lobe_radius( - giantstar_withprofile.mass, compstar.mass, - a_orb=cf.orbital_separation_from_period( - binary_withprofile.orbital_period, giantstar_withprofile.mass, - compstar.mass))), "CEE test 6 failed") - - self.assertTrue( - (abs(binary_withprofile.orbital_period - 0.12123905531545925) < - 1.0), - "CEE test 6 failed") - - def test_common_envelope_7(self): - kwargs = {'prescription': 'alpha-lambda', - "common_envelope_option_for_lambda" : 'lambda_from_profile_gravitational_plus_internal_minus_recombination'} - - CEE = StepCEE(verbose=False, **kwargs) - - # testing with loading a profile of the donor at the moment of CEE to - # calculate the lamda CEE, taking into account also the internal - # energy - recombination energy - profile_donor_name = os.path.join( - PATH_TO_DATA, - 'giant_profile_for_CEE_with_recombinationenergy_calculation.npy') - profile_donor = np.load(profile_donor_name) - PROPERTIES_STAR1_withprofile = { - 'mass': 22.77, - 'log_R': np.log10(1319.0), - 'he_core_mass': 11.0, - 'he_core_radius': 0.6, - 'state': 'H-rich_Shell_H_burning', - 'metallicity' : 0.0142, - 'profile': profile_donor, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.0, - "center_h1" : 1.0, - "center_c12" : 0.01, - } - giantstar_withprofile = SingleStar(**PROPERTIES_STAR1_withprofile) - PROPERTIES_STAR2 = { - 'mass': 10., - 'log_R': np.log10(0.0001), - 'he_core_mass': 0.0, - 'he_core_radius': 0.0, - 'state': 'BH' - } - compstar = SingleStar(**PROPERTIES_STAR2) - - orbital_separation_for_RLOF = 10**giantstar_withprofile.log_R / cf.roche_lobe_radius( - giantstar_withprofile.mass, compstar.mass, a_orb=1) - orbital_period_for_RLOF = cf.orbital_period_from_separation( - orbital_separation_for_RLOF, giantstar_withprofile.mass, - compstar.mass) - - PROPERTIES_BINARY_withprofile = { - "binary_state": "RLO1", - "event": "oCE1", - "orbital_period": orbital_period_for_RLOF - } - binary_withprofile = BinaryStar(giantstar_withprofile, compstar, - **PROPERTIES_BINARY_withprofile) - # options: 'default_lambda', 'lambda_from_profile_gravitational', - # 'lambda_from_profile_gravitational_plus_internal', - # 'lambda_from_profile_gravitational_plus_internal_minus_recombination' - #binary_withprofile.properties.common_envelope_option_for_lambda = "lambda_from_profile_gravitational_plus_internal_minus_recombination" - #print(binary_withprofile.properties.common_envelope_option_for_lambda) - CEE(binary_withprofile) - #print(binary_withprofile.event) - #self.assertTrue(binary_withprofile.event == 'redirect', - # "CEE test 7 failed") - self.assertTrue((10**giantstar_withprofile.log_R - cf.roche_lobe_radius( - giantstar_withprofile.mass, compstar.mass, - a_orb=cf.orbital_separation_from_period( - binary_withprofile.orbital_period, giantstar_withprofile.mass, - compstar.mass))), "CEE test 7 failed") - self.assertTrue( - (abs(binary_withprofile.orbital_period - 0.3287114957064215) < - 1.0), - "CEE test 7 failed") - - def test_common_envelope_8(self): - kwargs = {'prescription': 'alpha-lambda', - "common_envelope_option_for_lambda" : 'lambda_from_profile_gravitational_plus_internal_minus_recombination'} - - CEE = StepCEE(verbose=False, **kwargs) - - profile_donor_name = os.path.join( - PATH_TO_DATA, - 'giant_profile_for_CEE_with_recombinationenergy_calculation.npy') - profile_donor = np.load(profile_donor_name) - PROPERTIES_STAR1_withprofile = { - 'mass': 22.77, - 'log_R': np.log10(1319.0), - 'he_core_mass': 11.0, - 'he_core_radius': 0.6, - 'state': 'H-rich_Shell_H_burning', - 'metallicity' : 0.0142, - 'profile': profile_donor, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.0, - "center_h1" : 1.0, - "center_c12" : 0.01, - } - giantstar_withprofile = SingleStar(**PROPERTIES_STAR1_withprofile) - PROPERTIES_STAR2 = { - 'mass': 20., - 'log_R': np.log10(0.0001), - 'he_core_mass': 0.0, - 'he_core_radius': 0.0, - 'state': 'BH' - } #radius = 10km - compstar = SingleStar(**PROPERTIES_STAR2) - - orbital_separation_for_RLOF = 10**giantstar_withprofile.log_R / cf.roche_lobe_radius( - giantstar_withprofile.mass, compstar.mass, a_orb=1) - orbital_period_for_RLOF = cf.orbital_period_from_separation( - orbital_separation_for_RLOF, giantstar_withprofile.mass, - compstar.mass) - - PROPERTIES_BINARY_withprofile = { - "binary_state": "RLO1", - "event": "oCE1", - "orbital_period": orbital_period_for_RLOF - } - binary_withprofile = BinaryStar(giantstar_withprofile, compstar, - **PROPERTIES_BINARY_withprofile) - # options: 'default_lambda', 'lambda_from_profile_gravitational', - # 'lambda_from_profile_gravitational_plus_internal', - # 'lambda_from_profile_gravitational_plus_internal_minus_recombination' - #binary_withprofile.properties.common_envelope_option_for_lambda = "lambda_from_profile_gravitational_plus_internal_minus_recombination" - #print(binary_withprofile.properties.common_envelope_option_for_lambda) - CEE(binary_withprofile) - #print(binary_withprofile.event) - #self.assertTrue(binary_withprofile.event == 'redirect', - # "CEE test 8 failed") - #self.assertTrue(binary_withprofile.star_1.state == "stripped_He_Core_He_burning", - # "CEE test 8 failed") - self.assertTrue("stripped_He" in binary_withprofile.star_1.state, - "CEE test 8 failed") - self.assertTrue( - (abs(binary_withprofile.orbital_period - 0.7636524660283687) < - 1.0), - "CEE test 8 failed") - - def test_common_envelope_9(self): - kwargs = {'prescription': 'alpha-lambda', - "common_envelope_option_for_lambda" : 'lambda_from_profile_gravitational_plus_internal_minus_recombination'} - - CEE = StepCEE(verbose=False, **kwargs) - - profile_donor_name = os.path.join(PATH_TO_DATA, - 'caseB_CEE_profile.npy') - profile_donor = np.load(profile_donor_name) - PROPERTIES_STAR1_withprofile = { - 'mass': 28.04, - 'log_R': np.log10(927.0), - 'he_core_mass': 11.0, - 'he_core_radius': 0.6, - 'state': 'H-rich_Shell_H_burning', - 'metallicity' : 0.0142, - 'profile': profile_donor, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.0, - "center_h1" : 1.0, - "center_c12" : 0.01, - } - giantstar_withprofile = SingleStar(**PROPERTIES_STAR1_withprofile) - PROPERTIES_STAR2 = { - 'mass': 20., - 'log_R': np.log10(0.0001), - 'he_core_mass': 0.0, - 'he_core_radius': 0.0, - 'state': 'BH' - } #radius = 10km - compstar = SingleStar(**PROPERTIES_STAR2) - - orbital_separation_for_RLOF = 10**giantstar_withprofile.log_R / cf.roche_lobe_radius( - giantstar_withprofile.mass, compstar.mass, a_orb=1) - orbital_period_for_RLOF = cf.orbital_period_from_separation( - orbital_separation_for_RLOF, giantstar_withprofile.mass, - compstar.mass) - - PROPERTIES_BINARY_withprofile = { - "binary_state": "RLO1", - "event": "oCE1", - "orbital_period": orbital_period_for_RLOF - } - binary_withprofile = BinaryStar(giantstar_withprofile, compstar, - **PROPERTIES_BINARY_withprofile) - # options: 'default_lambda', 'lambda_from_profile_gravitational', - # 'lambda_from_profile_gravitational_plus_internal', - # 'lambda_from_profile_gravitational_plus_internal_minus_recombination' - #binary_withprofile.properties.common_envelope_option_for_lambda = "lambda_from_profile_gravitational_plus_internal_minus_recombination" - #print(binary_withprofile.properties.common_envelope_option_for_lambda) - CEE(binary_withprofile) - #print(binary_withprofile.event) - #self.assertTrue(binary_withprofile.event == 'redirect', - # "CEE test 9 failed event") - self.assertTrue("stripped_He" in binary_withprofile.star_1.state, - "CEE test 9 failed state") - self.assertTrue( - (abs(binary_withprofile.orbital_period - 0.166535882054919) < - 1.0), - "CEE test 9 failed tolerance") - - def test_common_envelope_10(self): - kwargs = {'prescription': 'alpha-lambda', - "common_envelope_option_for_lambda" : 'lambda_from_profile_gravitational_plus_internal_minus_recombination'} - - CEE = StepCEE(verbose=False, **kwargs) - - profile_donor_name = os.path.join(PATH_TO_DATA, - 'caseB_CEE_profile.npy') - profile_donor = np.load(profile_donor_name) - PROPERTIES_STAR1_withprofile = { - 'mass': 28.04, - 'log_R': np.log10(927.0), - 'he_core_mass': 11.0, - 'he_core_radius': 0.6, - 'state': 'H-rich_Shell_H_burning', - 'metallicity' : 0.0142, - 'profile': profile_donor, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.0, - "center_h1" : 1.0, - "center_c12" : 0.01, - } - giantstar_withprofile = SingleStar(**PROPERTIES_STAR1_withprofile) - PROPERTIES_STAR2 = { - 'mass': 2., - 'log_R': np.log10(1.5), - 'he_core_mass': 0.0, - 'he_core_radius': 0.0, - 'state': 'H-rich_Core_H_burning', - 'metallicity' : 0.0142, - "log_Lnuc": -1e6, # arbitrary - "log_LHe": -1e7, # arbitrary - "center_he4" : 0.49, - "center_h1" : 0.49, - "center_c12" : 0.01, - } #radius = 10km - compstar = SingleStar(**PROPERTIES_STAR2) - - orbital_separation_for_RLOF = 10**giantstar_withprofile.log_R / cf.roche_lobe_radius( - giantstar_withprofile.mass, compstar.mass, a_orb=1) - orbital_period_for_RLOF = cf.orbital_period_from_separation( - orbital_separation_for_RLOF, giantstar_withprofile.mass, - compstar.mass) - - PROPERTIES_BINARY_withprofile = { - "binary_state": "RLO1", - "event": "oCE1", - "orbital_period": orbital_period_for_RLOF - } - binary_withprofile = BinaryStar(giantstar_withprofile, compstar, - **PROPERTIES_BINARY_withprofile) - # options: 'default_lambda', 'lambda_from_profile_gravitational', - # 'lambda_from_profile_gravitational_plus_internal', - # 'lambda_from_profile_gravitational_plus_internal_minus_recombination' - #binary_withprofile.properties.common_envelope_option_for_lambda = "lambda_from_profile_gravitational_plus_internal_minus_recombination" - #print(binary_withprofile.properties.common_envelope_option_for_lambda) - CEE(binary_withprofile) - #print(binary_withprofile.event) - self.assertTrue(binary_withprofile.state == "merged", - "CEE test 10 failed") diff --git a/posydon/tests/binary_evol/DT/test_step_detached.py b/posydon/tests/binary_evol/DT/test_step_detached.py deleted file mode 100644 index 00fbb9db43..0000000000 --- a/posydon/tests/binary_evol/DT/test_step_detached.py +++ /dev/null @@ -1,404 +0,0 @@ -import os -import unittest - -from posydon.binary_evol.binarystar import BinaryStar -from posydon.binary_evol.DT.step_detached import detached_step, diffeq -from posydon.binary_evol.simulationproperties import SimulationProperties -from posydon.binary_evol.singlestar import SingleStar -from posydon.config import PATH_TO_POSYDON -from posydon.utils import common_functions as cf -from posydon.utils import constants as const - -PATH_TO_DATA = os.path.join( - PATH_TO_POSYDON, "posydon/tests/data/POSYDON-UNIT-TESTS/binary_evol/detached/") -#eep_version = "POSYDON" - - -class TestDetached_step(unittest.TestCase): - def test_matching1_root(self): - method = "root" - matching = detached_step(#grid='POSYDON', - path=PATH_TO_DATA, - matching_method=method, - #eep_version=eep_version, - verbose=False) - get_mist0 = detached_step.get_mist0 - get_track_val = detached_step.get_track_val - htrack = True - PROPERTIES_STAR = { - "mass": 60.0, - "log_R": 1.0, - "mdot": -(10.0**(-5)), - "state": "H-rich_Core_H_burning", - "center_he4": 0.48, - "center_h1": 0.5, - "total_moment_of_inertia": 10.0**57, - "log_total_angular_momentum": 52, - "he_core_mass": 0.0, - "surface_he4": 0.28, - "surface_h1": 0.7, - } - m0, t = get_mist0(matching, SingleStar(**PROPERTIES_STAR),htrack) - - self.assertAlmostEqual( - m0, - 64.2410914922183, - places=1, - msg= - "Initial mass in MIST matching not exactly what expected. Should be 64.68538051198551", - ) - self.assertAlmostEqual( - get_track_val(matching, "mass",htrack, m0, t), - 60.0000000000419, - places=3, - msg= - "Current mass in matching not exactly what expected. Should be 60.0000000000419", - ) - self.assertAlmostEqual( - get_track_val(matching, "log_R",htrack, m0, t), - 1.1282247490900794, - places=1, - msg= - "Current log_R in matching not exactly what expected. Should be 1.1282247490900794", - ) - self.assertAlmostEqual( - get_track_val(matching, "center_he4",htrack, m0, t), - 0.4861139708172655, - places=2, - msg= - "Current center_he4 in matching not exactly what expected. Should be 0.4861139708172655", - ) - self.assertAlmostEqual( - get_track_val(matching, "he_core_mass",htrack, m0, t), - 0.0, - places=1, - msg= - "Current he_core_mass matching not exactly what expected. Should be 0.0", - ) - - def test_matching1_minimize(self): - method = "minimize" - matching = detached_step(#grid='POSYDON', - path=PATH_TO_DATA, - matching_method=method, - #eep_version=eep_version, - verbose=False) - get_mist0 = detached_step.get_mist0 - get_track_val = detached_step.get_track_val - htrack = True - PROPERTIES_STAR = { - "mass": 60.0, - "log_R": 1.0, - "mdot": -(10.0**(-5)), - "state": "H-rich_Core_H_burning", - "center_he4": 0.48, - "center_h1": 0.5, - "total_moment_of_inertia": 10.0**57, - "log_total_angular_momentum": 52, - "he_core_mass": 0.0, - "surface_he4": 0.28, - "surface_h1": 0.7, - } - m0, t = get_mist0(matching, SingleStar(**PROPERTIES_STAR),htrack) - - #self.assertAlmostEqual( - # m0, - # 62.88453923015954, - # places= - # 1, # less accuracy because we try to fit more alternative parameters than "root" method at the same time - # msg= - # "Initial mass in MIST matching not exactly what expected. Should be 62.78903050084804", - #) - #self.assertAlmostEqual( - # get_track_val(matching, "mass",htrack, m0, t), - # 59.96234634155157, - # places=1, - # msg= - # "Current mass in matching not exactly what expected. Should be 59.96234634155157", - #) - #self.assertAlmostEqual( - # get_track_val(matching, "log_R",htrack, m0, t), - # 1.0973066851601672, - # places=1, - # msg= - # "Current log_R in matching not exactly what expected. Should be 1.0973066851601672", - #) - #self.assertAlmostEqual( - # get_track_val(matching, "center_he4",htrack, m0, t), - # 0.4252496982220549, - # places=1, - # msg= - # "Current center_he4 in matching not exactly what expected. Should be 0.4252496982220549", - #) - self.assertAlmostEqual( - get_track_val(matching, "he_core_mass",htrack, m0, t), - 0.0, - places=1, - msg= - "Current mass he_core_mass matching not exactly what expected. Should be 0.0", - ) - - def test_only_tides(self): - method = "minimize" - matching = detached_step(#grid='POSYDON', - path=PATH_TO_DATA, - matching_method=method, - #eep_version=eep_version, - verbose=False) - step_ODE_minimize_hist = detached_step( - #grid='POSYDON', - path=PATH_TO_DATA, - n_o_steps_history=30, - #eep_version=eep_version, - matching_method=method, - verbose=False, - ) - step_ODE_minimize_hist_onlytides = detached_step( - #grid='POSYDON', - path=PATH_TO_DATA, - n_o_steps_history=30, - matching_method=method, - #eep_version=eep_version, - verbose=False, - do_wind_loss=False, - do_tides=True, - do_gravitational_radiation=False, - do_magnetic_braking=False, - do_stellar_evolution_and_spin_from_winds=False, - ) - PROPERTIES_STAR1 = {"mass": 10.0, "state": "BH"} - LOW_MS_PROPERTIES_STAR2_non_rot = { - "mass": 8.0, - "log_R": 0.6, - "mdot": -(10.0**(-7)), - "state": "H-rich_Core_H_burning", - "center_he4": 0.28, - "center_h1": 0.7, - "total_moment_of_inertia": 10.0**57, - "log_total_angular_momentum": -10.99, # non-rotating - "he_core_mass": 0.0, - "surface_he4": 0.28, - "surface_h1": 0.7, - } - init_orbital_period = 10 - init_separation = cf.orbital_separation_from_period( - init_orbital_period, - PROPERTIES_STAR1["mass"], - LOW_MS_PROPERTIES_STAR2_non_rot["mass"], - ) - CLOSE_BINARY = { - "time": 5 * 10.0**6, - "orbital_period": init_orbital_period, - "separation": init_separation, - "state": "detached", - "eccentricity": 0.0, - "event": "None", - } - - binary = BinaryStar( - star_1=SingleStar(**PROPERTIES_STAR1), - star_2=SingleStar(**LOW_MS_PROPERTIES_STAR2_non_rot), - **CLOSE_BINARY) - binary.properties.max_simulation_time = 10.0**10 - step_ODE_minimize_hist_onlytides(binary) - - self.assertLessEqual( - getattr(binary, "separation_history")[-1], - getattr(binary, "separation_history")[0], - msg= - "final sepertation with tides only and a non-rotating donor should decrease.", - ) - - def test_tides_vs_tides_and_winds(self): - method = "minimize" - matching = detached_step(#grid='POSYDON', - path=PATH_TO_DATA, - matching_method=method, - #eep_version=eep_version, - verbose=False) - step_ODE_minimize_hist_tides_and_winds = detached_step( - #grid='POSYDON', - path=PATH_TO_DATA, - n_o_steps_history=30, - matching_method=method, - #eep_version=eep_version, - verbose=False, - do_wind_loss=True, - do_tides=True, - do_gravitational_radiation=False, - do_magnetic_braking=False, - do_stellar_evolution_and_spin_from_winds=False, - ) - step_ODE_minimize_hist_onlytides = detached_step( - #grid='POSYDON', - path=PATH_TO_DATA, - n_o_steps_history=30, - matching_method=method, - #eep_version=eep_version, - verbose=False, - do_wind_loss=False, - do_tides=True, - do_gravitational_radiation=False, - do_magnetic_braking=False, - do_stellar_evolution_and_spin_from_winds=False, - ) - - PROPERTIES_STAR1 = {"mass": 10.0, "state": "BH"} - LOW_MS_PROPERTIES_STAR2_non_rot = { - "mass": 8.0, - "log_R": 0.6, - "mdot": -(10.0**(-7)), - "state": "H-rich_Core_H_burning", - "center_he4": 0.28, - "center_h1": 0.7, - "total_moment_of_inertia": 10.0**57, - "log_total_angular_momentum": -10.99, # non-rotating - "he_core_mass": 0.0, - "surface_he4": 0.28, - "surface_h1": 0.7, - } - init_orbital_period = 10 - init_separation = cf.orbital_separation_from_period( - init_orbital_period, - PROPERTIES_STAR1["mass"], - LOW_MS_PROPERTIES_STAR2_non_rot["mass"], - ) - CLOSE_BINARY = { - "time": 5 * 10.0**6, - "orbital_period": init_orbital_period, - "separation": init_separation, - "state": "detached", - "eccentricity": 0.0, - "event": "None", - } - - binary = BinaryStar( - star_1=SingleStar(**PROPERTIES_STAR1), - star_2=SingleStar(**LOW_MS_PROPERTIES_STAR2_non_rot), - **CLOSE_BINARY) - binary_test = BinaryStar( - star_1=SingleStar(**PROPERTIES_STAR1), - star_2=SingleStar(**LOW_MS_PROPERTIES_STAR2_non_rot), - **CLOSE_BINARY) - - binary.properties.max_simulation_time = 10.0**10 - binary_test.properties.max_simulation_time = 10.0**10 - - step_ODE_minimize_hist_tides_and_winds(binary) - step_ODE_minimize_hist_onlytides(binary_test) - - self.assertLessEqual( - getattr(binary_test, "separation_history")[-1], - getattr(binary, "separation_history")[-1], - msg= - "final sepertation with tides only and a non-rotating donor should be lower than including winds too that widen the orbit too.", - ) - - # the following tests are out because they need more EEPS MIST models around their mass. If included they should work. - """ - def test_matching2_root(self): - method = "root" - matching = HMS_detached_step(PATH_TO_EEPS, matching_method=method, verbose=True) - get_mist0 = HMS_detached_step.get_mist0 - get_track_val = HMS_detached_step.get_track_val - PROPERTIES_STAR = { - "mass": 20.0, - "log_R": 2.5, - "mdot": -(10.0 ** (-5)), - "state": "PostMS", - "center_he4": 0.8, - "center_h1": 0.0, - "total_moment_of_inertia": 10.0 ** 57, - "log_total_angular_momentum": 52, - "he_core_mass": 7.0, - "surface_he4": 0.2, - "surface_h1": 0.7, - } - m0, t = get_mist0(matching, SingleStar(**PROPERTIES_STAR)) - - self.assertAlmostEqual( - m0, - 23.092430444226363, - places=5, - msg="Initial mass in MIST matching not exactly what expected. Should be 23.092430444226363", - ) - self.assertAlmostEqual( - get_track_val(matching, "mass", m0, t), - 20.000000000003112, - places=5, - msg="Current mass in matching not exactly what expected. Should be 20.000000000003112", - ) - self.assertAlmostEqual( - get_track_val(matching, "log_R", m0, t), - 3.006026700371161, - places=5, - msg="Current log_R in matching not exactly what expected. Should be 3.006026700371161", - ) - self.assertAlmostEqual( - get_track_val(matching, "center_he4", m0, t), - 0.6426242107646186, - places=5, - msg="Current center_he4 in matching not exactly what expected. Should be 0.6426242107646186", - ) - self.assertAlmostEqual( - get_track_val(matching, "he_core_mass", m0, t), - 6.99999999999913, - places=5, - msg="Current mass he_core_mass matching not exactly what expected. Should be 6.99999999999913", - ) - - def test_matching2_minimize(self): - method = "minimize" - matching = HMS_detached_step(PATH_TO_EEPS, matching_method=method, verbose=True) - get_mist0 = HMS_detached_step.get_mist0 - get_track_val = HMS_detached_step.get_track_val - PROPERTIES_STAR = { - "mass": 20.0, - "log_R": 2.5, - "mdot": -(10.0 ** (-5)), - "state": "PostMS", - "center_he4": 0.8, - "center_h1": 0.0, - "total_moment_of_inertia": 10.0 ** 57, - "log_total_angular_momentum": 52, - "he_core_mass": 7.0, - "surface_he4": 0.2, - "surface_h1": 0.7, - } - m0, t = get_mist0(matching, SingleStar(**PROPERTIES_STAR)) - - self.assertAlmostEqual( - m0, - 23.441360274390483, - places=5, - msg="Initial mass in MIST matching not exactly what expected. Should be 23.441360274390483", - ) - self.assertAlmostEqual( - get_track_val(matching, "mass", m0, t), - 21.931950016322705, - places=5, - msg="Current mass in matching not exactly what expected. Should be 21.931950016322705", - ) - self.assertAlmostEqual( - get_track_val(matching, "log_R", m0, t), - 2.5019520817980974, - places=5, - msg="Current log_R in matching not exactly what expected. Should be 2.5019520817980974", - ) - self.assertAlmostEqual( - get_track_val(matching, "center_he4", m0, t), - 0.9095107795447867, - places=5, - msg="Current center_he4 in matching not exactly what expected. Should be 0.9095107795447867", - ) - self.assertAlmostEqual( - get_track_val(matching, "he_core_mass", m0, t), - 6.86182071726521, - places=5, - msg="Current mass he_core_mass matching not exactly what expected. Should be 6.86182071726521", - ) - """ - - -if __name__ == "__main__": - unittest.main() diff --git a/posydon/tests/binary_evol/SN/test_profile_collapse.py b/posydon/tests/binary_evol/SN/test_profile_collapse.py deleted file mode 100644 index eb137cd6fa..0000000000 --- a/posydon/tests/binary_evol/SN/test_profile_collapse.py +++ /dev/null @@ -1,98 +0,0 @@ -import os -import unittest - -import posydon.utils.constants as const -from posydon.binary_evol.singlestar import SingleStar -from posydon.binary_evol.SN.profile_collapse import ( - compute_isco_properties, - do_core_collapse_BH, - get_initial_BH_properties, -) -from posydon.config import PATH_TO_POSYDON -from posydon.grids.psygrid import PSyGrid - -PATH_TO_GRID = os.path.join( - PATH_TO_POSYDON, "posydon/tests/data/POSYDON-UNIT-TESTS/" - "visualization/grid_unit_test_plot.h5") - -if not os.path.isfile(PATH_TO_GRID): - print(PATH_TO_GRID) - raise ValueError("Test grid for unit testing was not found!") - -# constants in CGS -G = const.standard_cgrav -c = const.clight -Mo = const.Msun - - -class TestProfileCollapse(unittest.TestCase): - def test_r_isco(self): - m_BH = 1. * Mo - self.assertAlmostEqual(compute_isco_properties(0., m_BH)[0] / - (G * m_BH / c**2), - 6.0, - places=5) - self.assertAlmostEqual(compute_isco_properties(0.999, m_BH)[0] / - (G * m_BH / c**2), - 1.1817646130335708, - places=5) - - def test_j_isco(self): - m_BH = 1. * Mo - self.assertAlmostEqual(compute_isco_properties(0, m_BH)[1] / - (G * m_BH / c), - 3.464101615137754, - places=5) - self.assertAlmostEqual(compute_isco_properties(0.999, m_BH)[1] / - (G * m_BH / c), - 1.3418378380509774, - places=5) - - def test_radiation_efficiency(self): - m_BH = 1. * Mo - self.assertAlmostEqual((1 - compute_isco_properties(0., m_BH)[2]), - 0.057190958417936644, - places=5) - self.assertAlmostEqual((1 - compute_isco_properties(0.999, m_BH)[2]), - 0.3397940734762088, - places=5) - - def test_low_spinning_He_star(self): - grid = PSyGrid(PATH_TO_GRID) - i = 42 - star = SingleStar(**{'profile': grid[i].final_profile1}) - m_rembar = grid[i].final_values['star_1_mass'] - mass_direct_collapse = 3. # Msun - delta_M = 0.5 # Msun - results = do_core_collapse_BH(star, m_rembar, mass_direct_collapse, - delta_M) - self.assertAlmostEqual(results[0], 13.365071929231409, places=5) - self.assertAlmostEqual(results[1], 8.98074719361575e-09, places=5) - - def test_midly_spinning_He_star(self): - grid = PSyGrid(PATH_TO_GRID) - i = 13 - star = SingleStar(**{'profile': grid[i].final_profile1}) - m_rembar = grid[i].final_values['star_1_mass'] - mass_direct_collapse = 3. # Msun - delta_M = 0.5 # Msun - results = do_core_collapse_BH(star, m_rembar, mass_direct_collapse, - delta_M) - self.assertAlmostEqual(results[0], 5.60832288900688, places=5) - self.assertAlmostEqual(results[1], 0.42583967572001924, places=5) - - def test_rapidly_spinning_He_star(self): - grid = PSyGrid(PATH_TO_GRID) - i = 6 - star = SingleStar(**{'profile': grid[i].final_profile1}) - m_rembar = grid[i].final_values['star_1_mass'] - mass_direct_collapse = 3. # Msun - delta_M = 0.5 # Msun - results = do_core_collapse_BH(star, m_rembar, mass_direct_collapse, - delta_M) - self.assertAlmostEqual(results[0], 38.50844589130613, places=5) - self.assertAlmostEqual(results[1], 0.9835226614001595, places=5) - - -if __name__ == "__main__": - unittest.main() diff --git a/posydon/tests/binary_evol/SN/test_step_SN.py b/posydon/tests/binary_evol/SN/test_step_SN.py deleted file mode 100644 index 065a8d0399..0000000000 --- a/posydon/tests/binary_evol/SN/test_step_SN.py +++ /dev/null @@ -1,567 +0,0 @@ -import os -import unittest - -import matplotlib.cm as cm -import numpy as np -import pandas as pd -from scipy.stats import maxwell - -from posydon.binary_evol.binarystar import BinaryStar -from posydon.binary_evol.singlestar import SingleStar -from posydon.binary_evol.SN.step_SN import StepSN -from posydon.config import PATH_TO_POSYDON - -# github action are not cloning the data submoule, data for unit testing -# are therefore stored to the unit test submodule - -path_to_Sukhbold_datasets = os.path.join( - PATH_TO_POSYDON, "posydon/tests/data/POSYDON-UNIT-TESTS/binary_evol/SN/") - -class TestStepSN(unittest.TestCase): - # TODO - ''' - """ - Test WD formation - """ - def test_WD_formation_RAPID(self): - M_CO = 1.3 - SN = StepSN( **{'mechanism' : 'Fryer+12-rapid', - 'engine' : None, - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False} ) - - star_prop = {'mass': M_CO / 0.7638113015667961, - 'co_core_mass': M_CO , - 'he_core_mass': M_CO / 0.7638113015667961, - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - star = SingleStar(**star_prop) - - star_he_core = star.he_core_mass - - M_rembar = SN.compute_m_rembar(star , None)[0] - - self.assertEqual( SN.SN_type , 'WD') - self.assertEqual( M_rembar , star_he_core) - - def test_WD_formation_DELAYED(self): - M_CO = 1.3 - SN = StepSN( **{'mechanism' : 'Fryer+12-delayed', - 'engine' : None, - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False} ) - - star_prop = {'mass': M_CO / 0.7638113015667961, - 'co_core_mass': M_CO , - 'he_core_mass': M_CO / 0.7638113015667961, - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - star = SingleStar(**star_prop) - - star_he_core = star.he_core_mass - - M_rembar = SN.compute_m_rembar(star , None)[0] - - self.assertEqual( SN.SN_type , 'WD') - self.assertEqual( M_rembar , star_he_core) - - def test_WD_formation_SUKHBOLDN20(self): - M_CO = 1.3 - SN = StepSN( **{'mechanism' : 'Sukhbold+16-engine', - 'engine' : 'N20', - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False, - 'path_to_datasets': path_to_Sukhbold_datasets} ) - - star_prop = {'mass': M_CO / 0.7638113015667961, - 'co_core_mass': M_CO , - 'he_core_mass': M_CO / 0.7638113015667961, - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - star = SingleStar(**star_prop) - - star_he_core = star.he_core_mass - - M_rembar = SN.compute_m_rembar(star , None)[0] - - self.assertEqual( SN.SN_type , 'WD') - self.assertEqual( M_rembar , star_he_core) - - - """ - Test ECSN formation - """ - def test_WD_formation_RAPID(self): - M_CO = 1.38 - SN = StepSN( **{'mechanism' : 'Fryer+12-rapid', - 'engine' : None, - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False} ) - - star_prop = {'mass': M_CO / 0.7638113015667961, - 'co_core_mass': M_CO , - 'he_core_mass': M_CO / 0.7638113015667961, - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - star = SingleStar(**star_prop) - - star_c_core = star.co_core_mass - - M_rembar = SN.compute_m_rembar(star , None)[0] - - self.assertEqual( SN.SN_type , 'ECSN') - self.assertEqual( M_rembar , star_c_core) - - def test_WD_formation_DELAYED(self): - M_CO = 1.38 - SN = StepSN( **{'mechanism' : 'Fryer+12-delayed', - 'engine' : None, - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False} ) - - star_prop = {'mass': M_CO / 0.7638113015667961, - 'co_core_mass': M_CO , - 'he_core_mass': M_CO / 0.7638113015667961, - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - star = SingleStar(**star_prop) - - star_c_core = star.co_core_mass - - M_rembar = SN.compute_m_rembar(star , None)[0] - - self.assertEqual( SN.SN_type , 'ECSN') - self.assertEqual( M_rembar , star_c_core) - - def test_WD_formation_SUKHBOLDN20(self): - M_CO = 1.38 - SN = StepSN( **{'mechanism' : 'Sukhbold+16-engine', - 'engine' : 'N20', - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False, - 'path_to_datasets': path_to_Sukhbold_datasets} ) - - star_prop = {'mass': M_CO / 0.7638113015667961, - 'co_core_mass': M_CO , - 'he_core_mass': M_CO / 0.7638113015667961, - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - star = SingleStar(**star_prop) - - star_c_core = star.co_core_mass - - M_rembar = SN.compute_m_rembar(star , None)[0] - - self.assertEqual( SN.SN_type , 'ECSN') - self.assertEqual( M_rembar , star_c_core) - - - """ - Test CCSN formation - """ - def test_CCSN_formation_RAPID(self): - M_CO = 2.0 - SN = StepSN( **{'mechanism' : 'Fryer+12-rapid', - 'engine' : None, - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False} ) - - star_prop = {'mass': M_CO / 0.7638113015667961, - 'co_core_mass': M_CO , - 'he_core_mass': M_CO / 0.7638113015667961, - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - star = SingleStar(**star_prop) - - star_he_core = star.he_core_mass - - M_rembar = SN.compute_m_rembar(star , None)[0] - - self.assertEqual( SN.SN_type , 'CCSN') - - def test_CCSN_formation_DELAYED(self): - M_CO = 2.0 - SN = StepSN( **{'mechanism' : 'Fryer+12-delayed', - 'engine' : None, - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False} ) - - star_prop = {'mass': M_CO / 0.7638113015667961, - 'co_core_mass': M_CO , - 'he_core_mass': M_CO / 0.7638113015667961, - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - star = SingleStar(**star_prop) - - star_he_core = star.he_core_mass - - M_rembar = SN.compute_m_rembar(star , None)[0] - - self.assertEqual( SN.SN_type , 'CCSN') - - def test_CCSN_formation_SUKHBOLDN20(self): - M_CO = 2.0 - SN = StepSN( **{'mechanism' : 'Sukhbold+16-engine', - 'engine' : 'N20', - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False, - 'path_to_datasets': path_to_Sukhbold_datasets} ) - - star_prop = {'mass': M_CO / 0.7638113015667961, - 'co_core_mass': M_CO , - 'he_core_mass': M_CO / 0.7638113015667961, - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - star = SingleStar(**star_prop) - - star_he_core = star.he_core_mass - - M_rembar = SN.compute_m_rembar(star , None)[0] - - self.assertEqual( SN.SN_type , 'CCSN') - - """ - Test PPISN - """ - def test_remnant_mass_PPISN(self): - M_He = 35.0 - - SN = StepSN( **{'mechanism' : 'Fryer+12-rapid', - 'engine' : None, - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False} ) - - star_prop = {'mass': M_He, - 'co_core_mass': M_He * 0.7638113015667961 , - 'he_core_mass': M_He , - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - - star = SingleStar(**star_prop) - - m_PISN = SN.PISN_prescription(star) - - SN.compute_m_rembar(star , m_PISN)[0] - - self.assertTrue( m_PISN > 0.0 ) - self.assertTrue( m_PISN <= 50.0 ) - self.assertEqual(SN.SN_type , 'PPISN') - - """ - Test PISN - """ - def test_remnant_mass_PPISN(self): - M_He = 70.0 - - SN = StepSN( **{'mechanism' : 'Fryer+12-rapid', - 'engine' : None, - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False} ) - - star_prop = {'mass': M_He, - 'co_core_mass': M_He * 0.7638113015667961 , - 'he_core_mass': M_He , - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - - star = SingleStar(**star_prop) - - m_PISN = SN.PISN_prescription(star) - - SN.compute_m_rembar(star , m_PISN)[0] - - self.assertTrue( np.isnan(m_PISN) ) - self.assertEqual(SN.SN_type , 'PISN') - - """ - Test kick distribution for ECSN - """ - def test_kick_ECSN(self): - SN = StepSN( **{'mechanism' : 'Fryer+12-rapid', - 'engine' : None, - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False} ) - - SN_type = np.array([]) - Vkick = np.array([]) - M_co = np.full_like(np.arange(50000)*1.0 , 1.38) - - # The He stars are created - for m_co in M_co: - star_prop = {'mass':m_co / 0.7638113015667961, - 'co_core_mass':m_co, - 'he_core_mass':m_co / 0.7638113015667961, - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - - - star = SingleStar(**star_prop) - - # The fallback fraction is stracted, this is not a random - # variable then is a fixed value for all explotions - f_fb = SN.compute_m_rembar(star, None)[1] - - # We perform the collapse to extract the SN type of the - # from the code - SN.collapse_star(star) - - if (SN.SN_type == 'CCSN') + (SN.SN_type == 'PPISN') : - kick = SN.generate_kick(star , SN.sigma_kick_CCSN) - sigma = SN.sigma_kick_CCSN - elif SN.SN_type == 'ECSN': - kick = SN.generate_kick(star , SN.sigma_kick_ECSN) - sigma = SN.sigma_kick_ECSN - - - SN_type = np.append(SN_type , SN.SN_type) - Vkick = np.append(Vkick , kick) - - star = None - - dist = (Vkick[SN_type == 'ECSN'] / (1.0 - f_fb)) - - sigma_ECSN = np.round(np.std(dist) / np.sqrt((3*np.pi - 8)/np.pi) , 2) - - print(sigma_ECSN) - - lower = sigma_ECSN <= (SN.sigma_kick_ECSN + 2) - upper = sigma_ECSN >= (SN.sigma_kick_ECSN - 2) - - self.assertTrue( lower ) - self.assertTrue( upper ) - - """ - Test kick distribution for CCSN - """ - def test_kick_CCSN(self): - SN = StepSN( **{'mechanism' : 'Fryer+12-rapid', - 'engine' : None, - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False} ) - - SN_type = np.array([]) - Vkick = np.array([]) - M_co = np.full_like(np.arange(50000)*1.0 , 8.0) - - # The He stars are created - for m_co in M_co: - star_prop = {'mass':m_co / 0.7638113015667961, - 'co_core_mass':m_co, - 'he_core_mass':m_co / 0.7638113015667961, - 'state':'stripped_He_Core_C_depleted', - 'profile':None , - 'spin': 0.0} - - - star = SingleStar(**star_prop) - - # The fallback fraction is stracted, this is not a random - # variable then is a fixed value for all explotions - f_fb = SN.compute_m_rembar(star, None)[1] - - # We perform the collapse to extract the SN type of the - # from the code - SN.collapse_star(star) - - if (SN.SN_type == 'CCSN') + (SN.SN_type == 'PPISN') : - kick = SN.generate_kick(star , SN.sigma_kick_CCSN) - sigma = SN.sigma_kick_CCSN - elif SN.SN_type == 'ECSN': - kick = SN.generate_kick(star , SN.sigma_kick_ECSN) - sigma = SN.sigma_kick_ECSN - - - SN_type = np.append(SN_type , SN.SN_type) - Vkick = np.append(Vkick , kick) - - star = None - - dist = (Vkick[SN_type == 'CCSN'] / (1.0 - f_fb)) - - sigma_CCSN = np.round(np.std(dist) / np.sqrt((3*np.pi - 8)/np.pi) , 2) - - print(sigma_CCSN) - - lower = sigma_CCSN <= (SN.sigma_kick_CCSN + 2) - upper = sigma_CCSN >= (SN.sigma_kick_CCSN - 2) - - self.assertTrue( lower ) - self.assertTrue( upper ) - - """ - Test generate kick for expanding orbit - """ - def test_generate_kick(self): - SN = StepSN( **{'mechanism' : 'Fryer+12-rapid', - 'engine' : None, - 'PISN' : 'Marchant+19', - 'ECSN' : 'cosmic', - 'max_neutrino_mass_loss' : 0., - 'kick' : True, - 'sigma_kick_CCSN' : 265.0, - 'sigma_kick_ECSN' : 20.0, - 'max_NS_mass' : 2.5, - 'verbose' : False} ) - - fallback = [] - - sep_i = [] - ecc_i = [] - - - sep_f = [] - ecc_f = [] - Vsys_f = [] - - # Loading the test data - def end(binary): - binary.event = 'END' - - properties_star1 = {"mass": 16.200984100257546, "state": "BH", "profile": None} - properties_star2 = {"mass": 5.497560636139926, - "state": "stripped_He_Core_C_depleted", - "profile": None, - 'he_core_mass': 5.497560636139926, - 'co_core_mass': 4.1990989449324205} - - BH = SingleStar(**properties_star1) - He_star = SingleStar(**properties_star2) - properties_binary = { - 'orbital_period' : 6.182118856988261, - 'eccentricity' : 0.0, - 'separation': 39.5265173131476, - 'state' : 'ZAMS', - 'event' : 'CC2', - 'V_sys' : [0, 0, 0], - 'mass_transfer_case' : None, - } - binary = BinaryStar(BH, He_star, **properties_binary) - - pop = [binary] - - - for i in range(len(pop)): - binary = pop[i] - - # We consider that the kicks will have the same direction - # as the velocity of the He star at the periapsis - binary.star_2.natal_kick_array = [None , 0., 0., 0.] - - # We save the orbital separation end eccentricity pre-supernova - sep_i.append( binary.separation ) - ecc_i.append( binary.eccentricity ) - - # We save the fallback fraction f_fb of the remnant - fallback.append(SN.compute_m_rembar(binary.star_2, None)[1]) - - # The orbital kick is applied to the three dimensional orbit - SN.orbital_kick(binary) - - # We save the orbital separation, eccentricity and kick velocity post-supernova - sep_f.append( binary.separation ) - ecc_f.append( binary.eccentricity ) - Vsys_f.append( binary.V_sys ) - - index = [sep_f[i] < sep_i[i] for i in range(len(sep_i))] - - # See if there is any orbit post supernova that shrinked more than one meter - smaller_orbits = np.array(np.array(sep_i)[index] - np.array(sep_f)[index] > 10**-8) - - orbit_comparision = np.sum(smaller_orbits) - - self.assertEqual( orbit_comparision , 0.0) - ''' - - -if __name__ == '__main__': - unittest.main() diff --git a/posydon/tests/binary_evol/test_BinaryStar.py b/posydon/tests/binary_evol/test_BinaryStar.py deleted file mode 100644 index b0f12fe7a7..0000000000 --- a/posydon/tests/binary_evol/test_BinaryStar.py +++ /dev/null @@ -1,127 +0,0 @@ -import os -import unittest - -import numpy as np - -from posydon.binary_evol.binarystar import BinaryStar -from posydon.binary_evol.singlestar import SingleStar -from posydon.config import PATH_TO_POSYDON -from posydon.grids.psygrid import PSyGrid - -PATH_TO_GRID = os.path.join( - PATH_TO_POSYDON, - "posydon/tests/data/POSYDON-UNIT-TESTS/" - "visualization/grid_unit_test_plot.h5" -) - -if not os.path.exists(PATH_TO_GRID): - raise ValueError("Test grid for unit testing was not found!") - - -class TestSingleStar(unittest.TestCase): - def test_BinaryStar_initialisation(self): - # load an example grid: compact object + He-star - grid = PSyGrid(PATH_TO_GRID) - - # initialise a star with the properties of run i=42 - i = 42 - - kwargs1 = { - 'state': 'stripped_He_Core_C_depleted', - 'metallicity': grid[i].initial_values['Z'], - 'mass': grid[i].history1['star_mass'][-1], - 'log_R': np.nan, - 'log_L': grid[i].history1['log_L'][-1], - 'lg_mdot': np.nan, - 'lg_system_mdot' : np.nan, - 'lg_wind_mdot': np.nan, - 'he_core_mass': grid[i].history1['he_core_mass'][-1], - 'he_core_radius': np.nan, - 'c_core_radius': grid[i].history1['he_core_mass'][-1], - 'o_core_mass': np.nan, - 'o_core_radius': np.nan, - 'center_h1': grid[i].history1['center_h1'][-1], - 'center_he4': grid[i].history1['center_he4'][-1], - 'center_c12': grid[i].history1['center_c12'][-1], - 'center_n14': np.nan, - 'center_o16': np.nan, - 'surface_h1': grid[i].history1['surface_h1'][-1], - 'surface_he4': np.nan, - 'surface_c12': np.nan, - 'surface_n14': np.nan, - 'surface_o16': np.nan, - 'log_LH': grid[i].history1['log_LH'][-1], - 'log_LHe': grid[i].history1['log_LHe'][-1], - 'log_LZ': grid[i].history1['log_LZ'][-1], - 'log_Lnuc': grid[i].history1['log_Lnuc'][-1], - 'c12_c12': grid[i].history1['c12_c12'][-1], - 'surf_avg_omega_div_omega_crit': np.nan, - 'total_moment_of_inertia': np.nan, - 'log_total_angular_momentum': np.nan, - 'spin': np.nan, - 'profile': grid[i].final_profile1 - } - - star_1 = SingleStar(**kwargs1) - - kwargs2 = { - 'state': 'stripped_He_Core_C_depleted', - 'metallicity': grid[i].initial_values['Z'], - 'mass': grid[i].initial_values['star_2_mass'], - 'log_R': np.nan, - 'log_L': np.nan, - 'lg_mdot': np.nan, - 'lg_system_mdot' : np.nan, - 'lg_wind_mdot': np.nan, - 'he_core_mass': np.nan, - 'he_core_radius': np.nan, - 'c_core_radius': np.nan, - 'o_core_mass': np.nan, - 'o_core_radius': np.nan, - 'center_h1': np.nan, - 'center_he4': np.nan, - 'center_c12': np.nan, - 'center_n14': np.nan, - 'center_o16': np.nan, - 'surface_h1': np.nan, - 'surface_he4': np.nan, - 'surface_c12': np.nan, - 'surface_n14': np.nan, - 'surface_o16': np.nan, - 'log_LH': np.nan, - 'log_LHe': np.nan, - 'log_LZ': np.nan, - 'log_Lnuc': np.nan, - 'c12_c12': np.nan, - 'surf_avg_omega_div_omega_crit': np.nan, - 'total_moment_of_inertia': np.nan, - 'log_total_angular_momentum': np.nan, - 'spin': np.nan, - 'profile': None - } - - star_2 = SingleStar(**kwargs2) - - kwargs3 = { - 'state': 'detached', - 'event': 'CC1', - 'time': grid.final_values['age'][i], - 'orbital_period': grid.final_values['period_days'][i], - 'eccentricity': 0., - 'separation': grid.final_values['binary_separation'][i], - 'V_sys': [0, 0, 0], - 'rl_relative_overflow_1' : np.nan, - 'rl_relative_overflow_2' : np.nan, - 'lg_mtransfer_rate': np.nan, - #'mass_transfer_case': None - } - - binary = BinaryStar(star_1, star_2, **kwargs3) - - # check that the above kwars have a history - for item in kwargs3.keys(): - self.assertIsInstance(getattr(binary, item + '_history'), list) - - -if __name__ == '__main__': - unittest.main() diff --git a/posydon/tests/binary_evol/test_SingleStar.py b/posydon/tests/binary_evol/test_SingleStar.py deleted file mode 100644 index 4fabf41981..0000000000 --- a/posydon/tests/binary_evol/test_SingleStar.py +++ /dev/null @@ -1,73 +0,0 @@ -import os -import unittest - -import numpy as np - -from posydon.binary_evol.singlestar import SingleStar -from posydon.config import PATH_TO_POSYDON -from posydon.grids.psygrid import PSyGrid - -PATH_TO_GRID = os.path.join( - PATH_TO_POSYDON, - "posydon/tests/data/POSYDON-UNIT-TESTS/" - "visualization/grid_unit_test_plot.h5" -) - -if not os.path.exists(PATH_TO_GRID): - raise ValueError("Test grid for unit testing was not found!") - - -class TestSingleStar(unittest.TestCase): - def test_SingleStar_initialisation(self): - # load an example grid: compact object + He-star - grid = PSyGrid(PATH_TO_GRID) - - # initialise a star with the properties of run i=42 - i = 42 - - # all STARPROPERTIES - kwargs = { - 'state': 'stripped_He_Central_C_depletion', - 'metallicity': grid[i].initial_values['Z'], - 'mass': grid[i].history1['star_mass'][-1], - 'log_R': np.nan, - 'log_L': grid[i].history1['log_L'][-1], - 'lg_mdot': np.nan, - 'lg_system_mdot': np.nan, - 'lg_wind_mdot': np.nan, - 'he_core_mass': grid[i].history1['he_core_mass'][-1], - 'he_core_radius': np.nan, - 'c_core_radius': grid[i].history1['he_core_mass'][-1], - 'o_core_mass': np.nan, - 'o_core_radius': np.nan, - 'center_h1': grid[i].history1['center_h1'][-1], - 'center_he4': grid[i].history1['center_he4'][-1], - 'center_c12': grid[i].history1['center_c12'][-1], - 'center_n14': np.nan, - 'center_o16': np.nan, - 'surface_h1': grid[i].history1['surface_h1'][-1], - 'surface_he4': np.nan, - 'surface_c12': np.nan, - 'surface_n14': np.nan, - 'surface_o16': np.nan, - 'log_LH': grid[i].history1['log_LH'][-1], - 'log_LHe': grid[i].history1['log_LHe'][-1], - 'log_LZ': grid[i].history1['log_LZ'][-1], - 'log_Lnuc': grid[i].history1['log_Lnuc'][-1], - 'c12_c12': grid[i].history1['c12_c12'][-1], - 'surf_avg_omega_div_omega_crit': np.nan, - 'total_moment_of_inertia': np.nan, - 'log_total_angular_momentum': np.nan, - 'spin': np.nan, - 'profile': grid[i].final_profile1 - } - - star = SingleStar(**kwargs) - - # check that the above kwars have a history - for item in kwargs.keys(): - self.assertIsInstance(getattr(star, item + '_history'), list) - - -if __name__ == '__main__': - unittest.main() diff --git a/posydon/tests/data/POSYDON-UNIT-TESTS b/posydon/tests/data/POSYDON-UNIT-TESTS deleted file mode 160000 index eaf9d59229..0000000000 --- a/posydon/tests/data/POSYDON-UNIT-TESTS +++ /dev/null @@ -1 +0,0 @@ -Subproject commit eaf9d592291f093cc0095e13c93d431c5b6051da diff --git a/posydon/tests/interpolation/test_data_scaling.py b/posydon/tests/interpolation/test_data_scaling.py deleted file mode 100644 index ba9e29fd73..0000000000 --- a/posydon/tests/interpolation/test_data_scaling.py +++ /dev/null @@ -1,251 +0,0 @@ -from unittest import TestCase - -import numpy as np - -from posydon.interpolation.data_scaling import DataScaler - - -class DataScaler_test(TestCase): - def setUp(self): - self.sc = DataScaler() - self.x = np.array([1,2,3,4]) - self.y = -self.x.copy() - - def test_fit(self): - # not a 1D array - with self.assertRaises(AssertionError): - self.sc.fit([12,2,3]) #list - with self.assertRaises(AssertionError): - self.sc.fit(np.ones((5,1))) # list - # default value 'none' - with self.subTest(i=0): - self.sc.fit(self.x) - self.assertIsInstance(self.sc.params, list) - self.assertEqual(self.sc.method,'none') - self.assertEqual(len(self.sc.params),0) - # min_max - with self.subTest(i=1): - self.sc.fit(self.x, method='min_max') - self.assertEqual(self.sc.method, 'min_max') - self.assertEqual(len(self.sc.params),2) - self.assertEqual(self.sc.params[0], 1) - self.assertEqual(self.sc.params[1], 4) - self.assertEqual(self.sc.lower, -1) - self.assertEqual(self.sc.upper, 1) - # min_max modifying lower/upper - with self.subTest(i=2): - with self.assertRaises(AssertionError): - self.sc.fit(self.x, method='min_max', lower=2) - self.sc.fit(self.x, method='min_max', lower=-2, upper=0.5) - self.assertEqual(self.sc.params[0], 1) - self.assertEqual(self.sc.params[1], 4) - self.assertEqual(self.sc.lower, -2) - self.assertEqual(self.sc.upper, 0.5) - # max_abs - with self.subTest(i=3): - self.sc.fit(self.x, method='max_abs') - self.assertEqual(self.sc.method, 'max_abs') - self.assertEqual(len(self.sc.params), 1) - self.assertEqual(self.sc.params[0], 4) - self.sc.fit(self.y, method='max_abs') # check with negative numbers - self.assertEqual(self.sc.params[0], 4) - # standarize - with self.subTest(i=4): - self.sc.fit(self.x, method='standarize') - self.assertEqual(self.sc.method, 'standarize') - self.assertEqual(len(self.sc.params), 2) - self.assertEqual(self.sc.params[0], np.mean(self.x)) - self.assertEqual(self.sc.params[1], np.std(self.x)) - # log_min_max - with self.subTest(i=5): - self.sc.fit(self.x, method='log_min_max') - self.assertEqual(self.sc.method, 'log_min_max') - self.assertEqual(len(self.sc.params), 2) - self.assertEqual(self.sc.params[0], 0) - self.assertEqual(self.sc.params[1], np.log10(4)) - self.assertEqual(self.sc.lower, -1) - self.assertEqual(self.sc.upper, 1) - # log_min_max modifying lower/upper - with self.subTest(i=6): - with self.assertRaises(AssertionError): - self.sc.fit(self.x, method='log_min_max', lower=2) - self.sc.fit(self.x, method='log_min_max', lower=-2, upper=0.5) - self.assertEqual(self.sc.params[0], 0) - self.assertEqual(self.sc.params[1], np.log10(4)) - self.assertEqual(self.sc.lower, -2) - self.assertEqual(self.sc.upper, 0.5) - # log_max_abs - with self.subTest(i=7): - self.sc.fit(self.x, method='log_max_abs') - self.assertEqual(self.sc.method, 'log_max_abs') - self.assertEqual(len(self.sc.params), 1) - self.assertEqual(self.sc.params[0], np.log10(4)) - self.sc.fit(self.y, method='log_max_abs') # check with negative numbers - self.assertTrue(np.isnan(self.sc.params[0])) - # log_standarize - with self.subTest(i=8): - self.sc.fit(self.x, method='log_standarize') - self.assertEqual(self.sc.method, 'log_standarize') - self.assertEqual(len(self.sc.params), 2) - self.assertEqual(self.sc.params[0], np.mean(np.log10(self.x))) - self.assertEqual(self.sc.params[1], np.std(np.log10(self.x))) - # wrong method string - with self.assertRaises(ValueError): - self.sc.fit(self.x, method='wrong') - - def test_transform(self): - # check .fit has been run first - with self.assertRaises(AssertionError): - sc = DataScaler() - sc.transform(self.x) - # default value 'none' - with self.subTest(i=0): - self.sc.fit(self.x) - xt = self.sc.transform(self.x) - self.assertIsInstance(xt, np.ndarray) - self.assertEqual(len(xt.shape),1) - self.assertEqual(np.sum(np.abs(xt-self.x)),0) - # min_max - with self.subTest(i=1): - self.sc.fit(self.x, method='min_max') - xt = self.sc.transform(self.x) - self.assertAlmostEqual(xt.min(),self.sc.lower) - self.assertAlmostEqual(xt.max(), self.sc.upper) - # min_max modifying lower/upper - with self.subTest(i=2): - self.sc.fit(self.x, method='min_max', lower=-2, upper=0.5) - xt = self.sc.transform(self.x) - self.assertAlmostEqual(xt.min(), self.sc.lower) - self.assertAlmostEqual(xt.max(), self.sc.upper) - # max_abs - with self.subTest(i=3): - self.sc.fit(self.x, method='max_abs') - xt = self.sc.transform(self.x) - self.assertEqual(np.abs(xt).max(), 1) - self.assertGreaterEqual(np.abs(xt).min(),-1) - self.sc.fit(self.y, method='max_abs') # check with negative numbers - xt = self.sc.transform(self.x) - self.assertEqual(np.abs(xt).max(), 1) - self.assertGreaterEqual(np.abs(xt).min(), -1) - # standarize - with self.subTest(i=4): - self.sc.fit(self.x, method='standarize') - xt = self.sc.transform(self.x) - self.assertAlmostEqual(xt.mean(),0) - self.assertAlmostEqual(xt.std(), 1) - # log_min_max - with self.subTest(i=5): - self.sc.fit(self.x, method='log_min_max') - xt = self.sc.transform(self.x) - self.assertAlmostEqual(xt.min(), self.sc.lower) - self.assertAlmostEqual(xt.max(), self.sc.upper) - # log_min_max modifying lower/upper - with self.subTest(i=6): - self.sc.fit(self.x, method='log_min_max', lower=-2, upper=0.5) - xt = self.sc.transform(self.x) - self.assertAlmostEqual(xt.min(), self.sc.lower) - self.assertAlmostEqual(xt.max(), self.sc.upper) - # log_max_abs - with self.subTest(i=7): - self.sc.fit(self.x, method='log_max_abs') - xt = self.sc.transform(self.x) - self.assertEqual(np.abs(xt).max(), 1) - self.assertGreaterEqual(np.abs(xt).min(), -1) - # log_standarize - with self.subTest(i=8): - self.sc.fit(self.x, method='log_standarize') - xt = self.sc.transform(self.x) - self.assertAlmostEqual(xt.mean(), 0) - self.assertAlmostEqual(xt.std(), 1) - - def test_fit_and_transform(self): - - # default value 'none' - with self.subTest(i=0): - xt = self.sc.fit_and_transform(self.x) - self.assertIsInstance(xt, np.ndarray) - self.assertEqual(len(xt.shape),1) - self.assertEqual(np.sum(np.abs(xt-self.x)),0) - # min_max - with self.subTest(i=1): - xt = self.sc.fit_and_transform(self.x, method='min_max') - self.assertAlmostEqual(xt.min(),self.sc.lower) - self.assertAlmostEqual(xt.max(), self.sc.upper) - # min_max modifying lower/upper - with self.subTest(i=2): - xt = self.sc.fit_and_transform(self.x, method='min_max', lower=-2, upper=0.5) - self.assertAlmostEqual(xt.min(), self.sc.lower) - self.assertAlmostEqual(xt.max(), self.sc.upper) - # max_abs - with self.subTest(i=3): - xt = self.sc.fit_and_transform(self.x, method='max_abs') - self.assertEqual(np.abs(xt).max(), 1) - self.assertGreaterEqual(np.abs(xt).min(),-1) - xt = self.sc.fit_and_transform(self.y, method='max_abs') # check with negative numbers - self.assertEqual(np.abs(xt).max(), 1) - self.assertGreaterEqual(np.abs(xt).min(), -1) - # standarize - with self.subTest(i=4): - xt = self.sc.fit_and_transform(self.x, method='standarize') - self.assertAlmostEqual(xt.mean(),0) - self.assertAlmostEqual(xt.std(), 1) - # log_min_max - with self.subTest(i=5): - xt = self.sc.fit_and_transform(self.x, method='log_min_max') - self.assertAlmostEqual(xt.min(), self.sc.lower) - self.assertAlmostEqual(xt.max(), self.sc.upper) - # log_min_max modifying lower/upper - with self.subTest(i=6): - xt = self.sc.fit_and_transform(self.x, method='log_min_max', lower=-2, upper=0.5) - self.assertAlmostEqual(xt.min(), self.sc.lower) - self.assertAlmostEqual(xt.max(), self.sc.upper) - # log_max_abs - with self.subTest(i=7): - xt = self.sc.fit_and_transform(self.x, method='log_max_abs') - self.assertEqual(np.abs(xt).max(), 1) - self.assertGreaterEqual(np.abs(xt).min(), -1) - # log_standarize - with self.subTest(i=8): - xt = self.sc.fit_and_transform(self.x, method='log_standarize') - self.assertAlmostEqual(xt.mean(), 0) - self.assertAlmostEqual(xt.std(), 1) - - def test_inv_transform(self): - # default value 'none' - with self.subTest(i=0): - xt = self.sc.fit_and_transform(self.x) - self.assertAlmostEqual(np.sum(np.abs(self.sc.inv_transform(xt) - self.x)), 0) - # min_max - with self.subTest(i=1): - xt = self.sc.fit_and_transform(self.x, method='min_max') - self.assertAlmostEqual(np.sum(np.abs(self.sc.inv_transform(xt) - self.x)), 0) - # min_max modifying lower/upper - with self.subTest(i=2): - xt = self.sc.fit_and_transform(self.x, method='min_max', lower=-2, upper=0.5) - self.assertAlmostEqual(np.sum(np.abs(self.sc.inv_transform(xt) - self.x)), 0) - # max_abs - with self.subTest(i=3): - xt = self.sc.fit_and_transform(self.x, method='max_abs') - self.assertAlmostEqual(np.sum(np.abs(self.sc.inv_transform(xt) - self.x)), 0) - xt = self.sc.fit_and_transform(self.y, method='max_abs') # check with negative numbers - self.assertAlmostEqual(np.sum(np.abs(self.sc.inv_transform(xt) - self.y)), 0) - # standarize - with self.subTest(i=4): - xt = self.sc.fit_and_transform(self.x, method='standarize') - self.assertAlmostEqual(np.sum(np.abs(self.sc.inv_transform(xt) - self.x)), 0) - # log_min_max - with self.subTest(i=5): - xt = self.sc.fit_and_transform(self.x, method='log_min_max') - self.assertAlmostEqual(np.sum(np.abs(self.sc.inv_transform(xt) - self.x)), 0) - # log_min_max modifying lower/upper - with self.subTest(i=6): - xt = self.sc.fit_and_transform(self.x, method='log_min_max', lower=-2, upper=0.5) - self.assertAlmostEqual(np.sum(np.abs(self.sc.inv_transform(xt) - self.x)), 0) - # log_max_abs - with self.subTest(i=7): - xt = self.sc.fit_and_transform(self.x, method='log_max_abs') - self.assertAlmostEqual(np.sum(np.abs(self.sc.inv_transform(xt) - self.x)), 0) - # log_standarize - with self.subTest(i=8): - xt = self.sc.fit_and_transform(self.x, method='log_standarize') - self.assertAlmostEqual(np.sum(np.abs(self.sc.inv_transform(xt) - self.x)), 0) diff --git a/posydon/tests/interpolation/test_interpolation.py b/posydon/tests/interpolation/test_interpolation.py deleted file mode 100644 index 3e791481bb..0000000000 --- a/posydon/tests/interpolation/test_interpolation.py +++ /dev/null @@ -1,37 +0,0 @@ -# from unittest import TestCase -# -# import numpy as np -# import posydon.grids.psygrid as psg -# import posydon.interpolation.interpolation as psi -# -# try: -# import gpflow -# except ImportError: -# print("Import Error for TensorFlow and/or GPFlow, most, if not all " -# "features of the psyInterp class will not work, please check your installation " -# "of gpflow or tensorflow or install the correct gpflow by running pip install .[ml]") -# -# -# class Interpolation_test(TestCase): -# def setUp(self): -# # FIX PATH YOU CANNOT GIVE A LOCAL PATH -# self.grid = psg.PSyGrid() -# self.grid.load("/home/juanga/Desktop/data/grid_BH_He_star.h5") -# self.input_keys = self.grid.initial_values.dtype.names -# self.output_keys = self.grid.final_values.dtype.names[2:4] -# self.input_norms = ['log_min_max', 'log_min_max', 'log_min_max'] -# self.output_norms = ['log_standarize', 'log_standarize'] -# -# def test_init(self): -# m = psi.psyInterp(grid=self.grid, -# in_keys=self.input_keys, -# out_keys=self.output_keys, -# in_scaling=self.input_norms, -# out_scaling=self.output_norms) -# self.assertEqual(len(m.in_keys), len(self.input_keys)) -# self.assertEqual(len(m.out_keys), len(self.output_keys)) -# self.assertEqual(m.XYT.shape[0], m.N) -# self.assertEqual(m.XYT.shape[1], m.n_in+m.n_out) -# -# class SGPInterp_test(TestCase): -# pass diff --git a/posydon/tests/popsyn/test_binarypopulation.py b/posydon/tests/popsyn/test_binarypopulation.py deleted file mode 100644 index 8388acb69a..0000000000 --- a/posydon/tests/popsyn/test_binarypopulation.py +++ /dev/null @@ -1,180 +0,0 @@ -import os -import unittest - -import matplotlib.pyplot as plt -import numpy as np - -from posydon.binary_evol.flow_chart import flow_chart -from posydon.binary_evol.simulationproperties import SimulationProperties -from posydon.binary_evol.step_end import step_end -from posydon.popsyn.binarypopulation import BinaryPopulation - - -class TestBinaryPopulation(unittest.TestCase): - @classmethod - def setUpClass(cls): - np.random.seed(12345) - cls.POP_KWARGS = { - "number_of_binaries": int(500), - "primary_mass_min": 15 - } - - class MyEndStep(step_end): - def __call__(self, binary): - step_end.__call__(binary) - binary.star_1.mass = np.sqrt(binary.star_1.mass) - # change star_1 mass - - cls.SIM_PROP = SimulationProperties( - flow = (flow_chart, {}), - step_HMS_HMS = (MyEndStep, {}), - step_CO_HeMS = (MyEndStep, {}), - step_CO_HMS_RLO = (MyEndStep, {}), - step_detached = (MyEndStep, {}), - step_CE = (MyEndStep, {}), - step_SN = (MyEndStep, {}), - step_end = (MyEndStep, {}), - ) - - def test_init_0(self): - bin_pop = BinaryPopulation() - self.assertTrue(isinstance(bin_pop, BinaryPopulation)) - self.assertTrue(hasattr(bin_pop, 'population_properties')) - self.assertTrue(hasattr(bin_pop, 'entropy')) - - - def test_generate(self): - bin_pop = BinaryPopulation(**self.POP_KWARGS) - for i in range(bin_pop.number_of_binaries): - bin_pop.manager.generate(**self.POP_KWARGS) - self.assertTrue(len(bin_pop) == self.POP_KWARGS["number_of_binaries"]) - - self.assertTrue( [b.star_1.mass > self.POP_KWARGS["primary_mass_min"] - for b in bin_pop] ) - - # def test_generate_initial_binaries(self): - # bin_pop = BinaryPopulation(generate_initial_population=False, **self.POP_KWARGS) - # bin_pop.generate_initial_binaries() - # first_bin = bin_pop[1] - # bin_pop.generate_initial_binaries(overwrite=True) - # second_bin = bin_pop[1] - # self.assertFalse( - # first_bin is second_bin, msg="Binaries should not be the same object." - # ) - # - # def test_gen_init_bin_err(self): - # bin_pop = BinaryPopulation() - # with self.assertRaisesRegex( - # ValueError, "set overwrite=True to overwrite existing population" - # ): - # bin_pop.generate_initial_binaries() - # - def test_sim_properties(self): - bin_pop = BinaryPopulation(population_properties=self.SIM_PROP) - self.assertTrue(isinstance(bin_pop, BinaryPopulation)) - self.assertTrue(bin_pop.population_properties is self.SIM_PROP) - bin_pop.population_properties.load_steps() - return bin_pop - - # def test_evolve(self): - # bin_pop = self.test_sim_properties() - # test_ids = np.arange(0, 15, 1) - # original_bins = bin_pop.copy(ids=test_ids) - # bin_pop.evolve() - # for b in bin_pop: - # with self.subTest("Check event END", binary_ind=b.index): - # self.assertTrue(b.event == "END") - # - # for j, b in enumerate(bin_pop[test_ids]): - # with self.subTest( - # "Check mass changed", - # binary_ind=b.index, - # test_ind=original_bins[j].index, - # ): - # self.assertAlmostEqual( - # b.star_1.mass, np.sqrt(original_bins[j].star_1.mass), places=8 - # ) - - # def test_evolve_binary_population(self): - # # It is unclear how to test multiprocessing at the moment - # POP_KWARGS = {"number_of_binaries": int(500), "primary_mass_min": 15} - # - # def end(binary): - # binary.star_1.mass = np.sqrt(binary.star_1.mass) - # binary.event = "END" - # - # def get_sim_prop(): - # SIM_PROP = SimulationProperties( - # flow={("H-rich_Core_H_burning", "H-rich_Core_H_burning", - # "detached", "ZAMS"): "step_end"}, step_end=end, max_simulation_time=13.7e9) - # return SIM_PROP - # - # bin_pop = BinaryPopulation(population_properties=get_sim_prop, **POP_KWARGS) - # bin_pop.evolve_binary_population(num_batches=4, verbose=True, use_df=True) - - # def test_evolve_each_binary(self): - # bin_pop = self.test_sim_properties() - # for num, evolved_bin in enumerate(bin_pop.evolve_each_binary()): - # with self.subTest("Evolve generator", num=num): - # self.assertTrue(num == evolved_bin.index) - # self.assertTrue(evolved_bin.event == "END") - - # def test_copy(self): - # bin_pop = self.test_sim_properties() - # binary_copy = bin_pop.copy(ids=0) - # self.assertFalse(binary_copy is bin_pop[0]) - # all_binaries_copy = bin_pop.copy() - # self.assertFalse( - # any([copy_b is b for copy_b, b in zip(all_binaries_copy, bin_pop)]) - # ) - - # TODO: step_times is breaking to_df with only initialized binary / pop - # def test_to_df(self): - # bin_pop = self.test_sim_properties() - # self.assertTrue( isinstance(bin_pop.to_df(), pd.DataFrame) ) - - # def test_get_bin_by_index(self): - # bin_pop = self.test_sim_properties() - # test_indicies = [1, 6, 9, 8, 2] - # out_bins = bin_pop.get_binaries_by_index(test_indicies) - # self.assertTrue([b.index for b in out_bins] == test_indicies) - - # def test_bool_and_len(self): - # bin_pop = BinaryPopulation(population_properties=self.SIM_PROP) - # self.assertTrue(bool(bin_pop), msg="True if len self > 0") - # self.assertTrue( - # len(bin_pop) == bin_pop.number_of_binaries, msg="Should be len __binaries" - # ) - - # def test_get_subpopulation(self): - # bin_pop = BinaryPopulation(population_properties=self.SIM_PROP, **self.POP_KWARGS) - # for i in range(200, 300): - # bin_pop[i].star_2.state = "BH" - # subpop = bin_pop.get_subpopulation(star_1_states=None, star_2_states="BH") - # self.assertTrue(all([bin.index == 200 + j for j, bin in enumerate(subpop)])) - # self.assertTrue(all([bin.star_2.state == "BH" for bin in subpop])) - - # def test_pickle_and_load(self): - # bin_pop = BinaryPopulation() - # bin_pop.pickle("saved_population.pkl") - # self.assertTrue(os.path.isfile("saved_population.pkl")) - # - # loaded_pop = BinaryPopulation.load("saved_population.pkl") - # self.assertTrue(isinstance(loaded_pop, BinaryPopulation)) - # - # def test_unique_sim_prop(self): - # bin_pop = BinaryPopulation() - # prop = bin_pop.population_properties - # self.assertTrue( - # all([prop is b.properties for b in bin_pop]), - # msg="All binary properties should map to the same object.", - # ) - - def tearDown(self): - # remove pickled files - if os.path.isfile("saved_population.pkl"): - os.remove("saved_population.pkl") - - -if __name__ == "__main__": - unittest.main() diff --git a/posydon/tests/popsyn/test_synthetic_population.py b/posydon/tests/popsyn/test_synthetic_population.py deleted file mode 100644 index 58fc01a033..0000000000 --- a/posydon/tests/popsyn/test_synthetic_population.py +++ /dev/null @@ -1,645 +0,0 @@ -import os -import tempfile - -import numpy as np -import pandas as pd -import pytest - -from posydon.config import PATH_TO_POSYDON -from posydon.popsyn.synthetic_population import ( - History, - Oneline, - Population, - PopulationIO, - PopulationRunner, - parameter_array, -) -from posydon.utils.constants import Zsun - - -# Test the PopulationRunner class -class TestPopulationRunner: - # Test the initialisation of the PopulationRunner class - def test_init(self): - # Test the initialisation of the PopulationRunner class - poprun = PopulationRunner(PATH_TO_POSYDON+'posydon/popsyn/population_params_default.ini', verbose=True) - - # Check if the verbose attribute is set correctly - assert poprun.verbose == True, 'Verbose attribute is not set correctly' - - # Check if the solar_metallicities attribute is a list - assert isinstance(poprun.solar_metallicities, list), 'solar_metallicities attribute is not a list' - - # Check if the binary_populations attribute is a list - assert isinstance(poprun.binary_populations, list), 'binary_populations attribute is not a list' - - def test_init_invalid_ini_file(self): - with pytest.raises(ValueError): - PopulationRunner('invalid_file') - - def test_single_metallicity(self): - # copy the default ini file to a new file - new_ini_file = 'test_population_params.ini' - with open(PATH_TO_POSYDON+'posydon/popsyn/population_params_default.ini', 'r') as file: - data = file.read() - start = data.find('metallicity') - replace_str = 'metallicity = 0.0001' - data_new = data[:start] + replace_str + data[start+22:] - with open(new_ini_file, 'w') as file: - file.write(data_new) - - poprun = PopulationRunner(new_ini_file) - assert poprun.binary_populations[0].metallicity == 0.0001 - - def test_evolve(self, mocker): - mocker.patch('posydon.popsyn.binarypopulation.BinaryPopulation.evolve', return_value=None) - mocker.patch('posydon.popsyn.binarypopulation.BinaryPopulation.combine_saved_files', return_value=None) - - poprun = PopulationRunner(PATH_TO_POSYDON+'/posydon/popsyn/population_params_default.ini') - # set population to 1 binary - for pop in poprun.binary_populations: - pop.number_of_systems = 1 - - # create a temporary directory with 1e-04_Zsun_batches - os.makedirs('1e-04_Zsun_batches', exist_ok=True) - - poprun.evolve() - assert poprun.binary_populations, 'binary_populations attribute is empty after calling the evolve method' - - - def test_evolve_file_exists(self, mocker): - mocker.patch('posydon.popsyn.binarypopulation.BinaryPopulation.evolve', return_value=None) - mocker.patch('posydon.popsyn.binarypopulation.BinaryPopulation.combine_saved_files', return_value=None) - - # Create a temporary file with the 1e-04_ZSun_population.h5 name - open('1e-04_Zsun_population.h5', 'w').close() - - poprun = PopulationRunner(PATH_TO_POSYDON+'/posydon/popsyn/population_params_default.ini') - # set population to 1 binary - for pop in poprun.binary_populations: - pop.number_of_systems = 1 - - with pytest.raises(FileExistsError): - poprun.evolve() - - - # Test the evolve method - def test_changed_binarypop(self): - poprun = PopulationRunner(PATH_TO_POSYDON+'/posydon/popsyn/population_params_default.ini') - # test - assert poprun.binary_populations[0].metallicity == 0.0001 - # Check if the temp_directory attribute is set correctly - assert poprun.binary_populations[0].kwargs['temp_directory'] == '1e-04_Zsun_batches', 'temp_directory attribute is not set correctly' - - # Check if the binary_populations attribute is not empty after calling the evolve method - assert poprun.binary_populations, 'binary_populations attribute is empty after calling the evolve method' - - @classmethod - def teardown_class(cls): - if os.path.exists('1e-04_Zsun_batches'): - os.rmdir('1e-04_Zsun_batches') - if os.path.exists('test_population_params.ini'): - os.remove('test_population_params.ini') - if os.path.exists('1e-04_Zsun_population.h5'): - os.remove('1e-04_Zsun_population.h5') - - - -# Test the History class -class TestHistory: - - @classmethod - def setup_class(cls): - # Set up a test HDF5 file using pandas HDFStore - cls.filename = 'test_population.h5' - with pd.HDFStore(cls.filename, 'w') as store: - # Create a history dataframe - history_data = pd.DataFrame({'time': [1, 2, 3], 'event': ['ZAMS','oRLO1', 'CEE']}) - store.append('history',history_data, data_columns=True) - - cls.filename2 = 'test_population2.h5' - with pd.HDFStore(cls.filename2, 'w') as store: - # Create a history dataframe - history_data = pd.DataFrame({'time': [1, 2, 3], 'event': ['ZAMS','oRLO1', 'CEE']}) - store.append('history',history_data, data_columns=True) - - @classmethod - def teardown_class(cls): - os.remove(cls.filename) - - def setup_method(self): - self.history = History(self.filename, verbose=False, chunksize=10000) - - def test_init(self): - history = History(self.filename2, verbose=True, chunksize=10000) - assert history.filename == self.filename2, 'Filename is not set correctly' - assert history.verbose == True, 'Verbose attribute is not set correctly' - assert history.chunksize == 10000, 'Chunksize attribute is not set correctly' - - expected_lengths = pd.DataFrame(index=[0, 1, 2],data={'index': [1, 1, 1]}) - expected_lengths.index.name = 'index' - pd.testing.assert_frame_equal(history.lengths, expected_lengths, 'Lengths attribute is not equal to the expected dataframe') - - assert history.number_of_systems == 3, 'Number of systems attribute is not None' - assert history.columns.to_list() == ['time', 'event'], 'Columns attribute is not None' - - assert isinstance(history.indices, np.ndarray), 'Indices attribute is not an ndarray' - np.testing.assert_array_equal(history.indices, np.array([0, 1, 2]), 'Indices attribute is not equal to the expected list') - - with pytest.raises(FileNotFoundError): - History('invalid_filename.h5', verbose=False, chunksize=10000) - - def test_init_verbose_true(self): - history = History(self.filename, chunksize=10000) - assert history.verbose == False, 'Verbose attribute is not set correctly' - - - def test_getitem_single_index(self): - df = self.history[0] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == 1, 'Returned DataFrame does not have the correct length' - - - def test_getitem_multiple_indices(self): - df = self.history[[0, 1, 2]] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == 3, 'Returned DataFrame does not have the correct length' - - def test_getitem_index_array(self): - indices = np.array([0, 1, 2]) - df = self.history[indices] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == 3, 'Returned DataFrame does not have the correct length' - - def test_getitem_single_column(self): - column = 'time' - df = self.history[column] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df.columns) == 1, 'Returned DataFrame does not have the correct number of columns' - - def test_getitem_invalid_column(self): - column = 'invalid_column' - with pytest.raises(ValueError): - self.history[column] - - def test_getitem_invalid_keys(self): - columns = ['time', 'invalid_column'] - with pytest.raises(ValueError): - self.history[columns] - - def test_getitem_boolean_mask_numpy(self): - mask = (self.history['time'] > 1).to_numpy() - df = self.history[mask] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - - def test_getitem_boolean_mask_pandas(self): - mask = self.history['time'] > 1 - df = self.history[mask] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - - def test_getitem_multiple_columns(self): - columns = ['time', 'event'] - df = self.history[columns] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df.columns) == 2, 'Returned DataFrame does not have the correct number of columns' - - def test_getitem_invalid_key(self): - with pytest.raises(ValueError): - self.history[{1: 2}] - - def test_len(self): - length = len(self.history) - assert isinstance(length, int), 'Returned object is not an integer' - assert length == 3, 'Returned length is not correct' - - def test_head(self): - n = 2 - df = self.history.head(n) - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == n, 'Returned DataFrame does not have the correct length' - - def test_tail(self): - n = 2 - df = self.history.tail(n) - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == n, 'Returned DataFrame does not have the correct length' - - def test_repr(self): - representation = self.history.__repr__() - assert isinstance(representation, str), 'Returned object is not a string' - - def test_repr_html(self): - html_representation = self.history._repr_html_() - assert isinstance(html_representation, str), 'Returned object is not a string' - - - def test_select(self): - df = self.history.select(where="time > 1", start=0, stop=10, columns=['event', 'time']) - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df.columns) == 2, 'Returned DataFrame does not have the correct number of columns' - assert len(df) == 2, 'Returned DataFrame does not have the correct length' - - - -# Test the Oneline class -class TestOneline: - - @classmethod - def setup_class(cls): - # Set up a test HDF5 file using pandas HDFStore - cls.filename = 'test_oneline.h5' - with pd.HDFStore(cls.filename, 'w') as store: - # Create a oneline dataframe - oneline_data = pd.DataFrame({'time': [1, 2, 3], 'S1_mass_i': ['30','30', '70']}) - store.append('oneline', oneline_data, data_columns=True) - - def setup_method(self): - self.oneline = Oneline(self.filename, verbose=False, chunksize=10000) - - def test_init(self): - oneline = Oneline(self.filename, verbose=True, chunksize=5000) - - assert oneline.filename == self.filename, 'Filename is not set correctly' - assert oneline.verbose == True, 'Verbose attribute is not set correctly' - assert oneline.chunksize == 5000, 'Chunksize attribute is not set correctly' - assert oneline.number_of_systems == 3, 'Number of systems attribute is not set correctly' - assert oneline.columns.to_list() == ['time', 'S1_mass_i'], 'Columns attribute is not set correctly' - assert oneline.number_of_systems == 3, 'Number of systems attribute is not set correctly' - - assert isinstance(oneline.indices, np.ndarray), 'Indices attribute is not an ndarray' - np.testing.assert_array_equal(oneline.indices, np.array([0, 1, 2]), 'Indices attribute is not equal to the expected list') - - with pytest.raises(FileNotFoundError): - Oneline('invalid_filename.h5', verbose=False, chunksize=10000) - - def test_getitem_single_index(self): - df = self.oneline[0] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == 1, 'Returned DataFrame does not have the correct length' - - def test_getitem_multiple_indices(self): - df = self.oneline[[0, 1, 2]] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == 3, 'Returned DataFrame does not have the correct length' - - def test_getitem_index_array(self): - indices = np.array([0, 1, 2]) - df = self.oneline[indices] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == 3, 'Returned DataFrame does not have the correct length' - pd.testing.assert_frame_equal(df, - pd.DataFrame({'time': [1, 2, 3], - 'S1_mass_i': ['30','30', '70']}), - 'Returned DataFrame is not equal to the expected DataFrame') - - def test_getitem_slice(self): - df = self.oneline[0:2] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == 2, 'Returned DataFrame does not have the correct length' - pd.testing.assert_frame_equal(df, - pd.DataFrame({'time': [1, 2], - 'S1_mass_i': ['30','30']}),) - def test_getitem_endslice(self): - df = self.oneline[:2] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == 2, 'Returned DataFrame does not have the correct length' - pd.testing.assert_frame_equal(df, - pd.DataFrame({'time': [1, 2], - 'S1_mass_i': ['30','30']}),) - def test_getitem_beginslice(self): - df = self.oneline[1:] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == 2, 'Returned DataFrame does not have the correct length' - pd.testing.assert_frame_equal(df, - pd.DataFrame(index=[1, 2], - data={'time': [2, 3], - 'S1_mass_i': ['30', '70']}),) - def test_getitem_float_indices(self): - with pytest.raises(ValueError): - self.oneline[[0.5, 1.2]] - - - - def test_getitem_single_column(self): - column = 'time' - df = self.oneline[column] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df.columns) == 1, 'Returned DataFrame does not have the correct number of columns' - - def test_getitem_boolean_mask_numpy(self): - mask = (self.oneline['time'] > 1).to_numpy().flatten() - df = self.oneline[mask] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - - def test_getitem_boolean_mask_pandas(self): - mask = self.oneline['time'] > 1 - print(mask) - df = self.oneline[mask] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - - def test_getitem_multiple_columns(self): - columns = ['time', 'S1_mass_i'] - df = self.oneline[columns] - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df.columns) == 2, 'Returned DataFrame does not have the correct number of columns' - - def test_getitems_multiple_columns_invalid(self): - columns = ['time', 'invalid_column'] - with pytest.raises(ValueError): - self.oneline[columns] - - def test_getitem_invalid_key_type(self): - with pytest.raises(ValueError): - self.oneline[{1: 2}] - - def test_getitem_invalid_key(self): - with pytest.raises(ValueError): - self.oneline['invalid_key'] - - def test_len(self): - length = len(self.oneline) - assert isinstance(length, int), 'Returned object is not an integer' - assert length == 3, 'Returned length is not correct' - - def test_head(self): - n = 2 - df = self.oneline.head(n) - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == n, 'Returned DataFrame does not have the correct length' - - def test_tail(self): - n = 2 - df = self.oneline.tail(n) - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df) == n, 'Returned DataFrame does not have the correct length' - - def test_repr(self): - representation = self.oneline.__repr__() - assert isinstance(representation, str), 'Returned object is not a string' - - def test_repr_html(self): - html_representation = self.oneline._repr_html_() - assert isinstance(html_representation, str), 'Returned object is not a string' - - def test_select(self): - df = self.oneline.select(where="time > 1", start=0, stop=10, columns=['S1_mass_i', 'time']) - assert isinstance(df, pd.DataFrame), 'Returned object is not a DataFrame' - assert len(df.columns) == 2, 'Returned DataFrame does not have the correct number of columns' - assert len(df) == 2, 'Returned DataFrame does not have the correct length' - - - @classmethod - def teardown_class(cls): - os.remove(cls.filename) - -# Test the PopulationIO class -class TestPopulationIO: - - def setup_method(self): - self.filename = "test_population.h5" - - def teardown_method(self): - if os.path.exists(self.filename): - os.remove(self.filename) - - def test_init(self): - pop_io = PopulationIO() - assert pop_io.verbose == False, "Verbose attribute is not set correctly" - - - def test_invalid_filename(self): - pop_io = PopulationIO() - with pytest.raises(ValueError): - pop_io._load_metadata("invalid_filename") - - def test_save_and_load_mass_per_met(self): - population_io = PopulationIO() - population_io.verbose = True - population_io.mass_per_metallicity = pd.DataFrame({"metallicity": [0.02, 0.04], "mass": [1.0, 2.0]}) - population_io._save_mass_per_metallicity(self.filename) - - loaded_io = PopulationIO() - loaded_io.verbose = True - loaded_io._load_mass_per_metallicity(self.filename) - pd.testing.assert_frame_equal(population_io.mass_per_metallicity, loaded_io.mass_per_metallicity) - - def test_save_and_load_ini_params(self): - population_io = PopulationIO() - population_io.ini_params = {i:10 for i in parameter_array} - population_io._save_ini_params(self.filename) - - loaded_io = PopulationIO() - loaded_io._load_ini_params(self.filename) - - assert population_io.ini_params == loaded_io.ini_params, "Loaded ini_params are not equal to the saved ini_params" - - def test_save_and_load_metadata(self): - pop_io = PopulationIO() - pop_io.verbose = True - pop_io.ini_params = {i:10 for i in parameter_array} - pop_io._save_ini_params(self.filename) - pop_io.mass_per_metallicity = pd.DataFrame({"metallicity": [0.02, 0.04], "mass": [1.0, 2.0]}) - pop_io._save_mass_per_metallicity(self.filename) - - load_io = PopulationIO() - load_io._load_metadata(self.filename) - - assert pop_io.ini_params == load_io.ini_params, "Loaded ini_params are not equal to the saved ini_params" - assert pop_io.mass_per_metallicity.equals(load_io.mass_per_metallicity), "Loaded mass_per_metallicity is not equal to the saved mass_per_metallicity" - - - -class TestPopulation: - def setup_method(self): - pass - - def teardown_method(self): - # Clean up any resources used by the test - pass - - def setup_class(self): - self.filename1 = "no_mass_per_met_population.h5" - self.filename2 = "history_population.h5" - self.filename3 = "oneline_population.h5" - self.history_data = pd.DataFrame({'time': [1, 2, 3], 'event': ['ZAMS','oRLO1', 'CEE']}) - self.oneline_data = pd.DataFrame({'time': [1, 2, 3], 'S1_mass_i': [30, 30, 70], 'S2_mass_i': [30, 30, 70.]}) - self.formation_channels = pd.DataFrame({'channel': ['channel1', 'channel2', 'channel3'], 'channel_debug':['debug1', 'debug2', 'debug3']}) - - # create a file with only history and oneline data - with pd.HDFStore(self.filename1, 'w') as store: - store.append('history',self.history_data, data_columns=True) - store.append('oneline', self.oneline_data, data_columns=True) - - with pd.HDFStore(self.filename2, 'w') as store: - store.append('history', self.history_data, data_columns=True) - - with pd.HDFStore(self.filename3, 'w') as store: - store.append('oneline', self.oneline_data, data_columns=True) - - def teardown_class(self): - if os.path.exists(self.filename1): - os.remove(self.filename1) - if os.path.exists(self.filename2): - os.remove(self.filename2) - if os.path.exists(self.filename3): - os.remove(self.filename3) - - @pytest.fixture - def mass_per_met_pop(self): - self.filename = "mass_per_met_population.h5" - with pd.HDFStore(self.filename, 'w') as store: - store.append('history', self.history_data, data_columns=True) - store.append('oneline', self.oneline_data, data_columns=True) - - pop = Population(self.filename, verbose=True, metallicity=0.02, ini_file=PATH_TO_POSYDON+'/posydon/popsyn/population_params_default.ini') - yield - if os.path.exists(self.filename): - os.remove(self.filename) - - @pytest.fixture - def no_mass_per_met_pop(self): - self.filename = "no_mass_per_met_population.h5" - with pd.HDFStore(self.filename, 'w') as store: - store.append('history', self.history_data, data_columns=True) - store.append('oneline', self.oneline_data, data_columns=True) - yield - if os.path.exists(self.filename): - os.remove(self.filename) - - @pytest.fixture - def mass_per_met_pop_channels(self): - self.filename = "mass_per_met_population.h5" - with pd.HDFStore(self.filename, 'w') as store: - store.append('history', self.history_data, data_columns=True) - store.append('oneline', self.oneline_data, data_columns=True) - store.append('formation_channels', self.formation_channels, data_columns=True) - pop = Population(self.filename, verbose=True, metallicity=0.02, ini_file=PATH_TO_POSYDON+'/posydon/popsyn/population_params_default.ini') - yield - if os.path.exists(self.filename): - os.remove(self.filename) - - @pytest.fixture - def clean_up_selection_file(self): - self.outfile = "test_selection.h5" - yield - if os.path.exists(self.outfile): - os.remove(self.outfile) - - def test_init_invalid_file(self): - with pytest.raises(ValueError): - pop = Population('invalid_filename') - - def test_init_no_history(self): - with pytest.raises(ValueError): - pop = Population(self.filename3) - - def test_init_no_oneline(self): - with pytest.raises(ValueError): - pop = Population(self.filename2) - - def test_init_no_mass_per_met(self): - with pytest.raises(ValueError): - pop = Population(self.filename1, verbose=True) - - - def test_init_mass_per_met_calc(self, no_mass_per_met_pop: None): - pop = Population(self.filename, verbose=True, metallicity=1., ini_file=PATH_TO_POSYDON+'/posydon/popsyn/population_params_default.ini') - # check that the history and oneline data are read correctly - pd.testing.assert_frame_equal(pop.history[:], self.history_data) - pd.testing.assert_frame_equal(pop.oneline[:], self.oneline_data) - assert pop.solar_metallicities == [1.] - assert pop.metallicities == [1*Zsun] - pd.testing.assert_frame_equal(pop.mass_per_metallicity, pd.DataFrame(index=[1.], data={'simulated_mass': [260.], 'underlying_mass': [1462.194834], 'number_of_systems': [3]})) - - pop = Population(self.filename, verbose=True, metallicity=1., ini_file=PATH_TO_POSYDON+'/posydon/popsyn/population_params_default.ini') - # check that the history and oneline data are the same - pd.testing.assert_frame_equal(pop.history[:], self.history_data) - pd.testing.assert_frame_equal(pop.oneline[:], self.oneline_data) - assert pop.solar_metallicities == [1.] - assert pop.metallicities == [1*Zsun] - pd.testing.assert_frame_equal(pop.mass_per_metallicity, pd.DataFrame(index=[1.], data={'simulated_mass': [260.], 'underlying_mass': [1462.194834], 'number_of_systems': [3]})) - - - def test_init(self,mass_per_met_pop: None): - pop = Population(self.filename) - # check that the history and oneline data are read correctly - pd.testing.assert_frame_equal(pop.history[:], self.history_data) - pd.testing.assert_frame_equal(pop.oneline[:], self.oneline_data) - assert pop.metallicities == [0.02*Zsun] - assert pop.solar_metallicities == [0.02] - tmp_df = pd.DataFrame(index=[0, 1, 2], data={'index': [1, 1, 1]}) - tmp_df.index.name = 'index' - pd.testing.assert_frame_equal(pop.history_lengths, tmp_df) - pd.testing.assert_frame_equal(pop.mass_per_metallicity, pd.DataFrame(index=[0.02], data={'simulated_mass': [260.], 'underlying_mass': [1462.194834], 'number_of_systems': [3]})) - - - def test_read_formation_channels(self, mass_per_met_pop_channels: None): - pop = Population(self.filename) - # check that the formation channels are read correctly - assert pop.formation_channels.equals(self.formation_channels) - - def test_export_selection(self, mass_per_met_pop: None, clean_up_selection_file: None): - selection = [1, 2] - chunksize = 1000 - pop = Population(self.filename) - pop.export_selection(selection, self.outfile, chunksize) - assert os.path.exists(self.outfile) - assert pd.read_hdf(self.outfile, 'history').shape[0] == 2 - assert pd.read_hdf(self.outfile, 'oneline').shape[0] == 2 - - def test_bad_name_export_selection(self, mass_per_met_pop: None): - selection = [1, 2] - chunksize = 1000 - pop = Population(self.filename) - with pytest.raises(ValueError): - pop.export_selection(selection, 'test_selection.csv', history_chunksize=chunksize) - - def test_append_selection(self, mass_per_met_pop: None, clean_up_selection_file: None): - selection = [1, 2] - chunksize = 1000 - pop = Population(self.filename) - pop.export_selection(selection, self.outfile, overwrite=True, history_chunksize=chunksize) - pop.export_selection(selection, self.outfile, overwrite=False, history_chunksize=chunksize) - - assert pd.read_hdf(self.outfile, 'history').shape[0] == 4 - assert pd.read_hdf(self.outfile, 'oneline').shape[0] == 4 - - def test_no_formation_channels(self, mass_per_met_pop: None): - pop = Population(self.filename, verbose=True) - assert pop.formation_channels is None - - def test_len(self, mass_per_met_pop: None): - pop = Population(self.filename) - assert len(pop) == 3 - - def test_columns(self, mass_per_met_pop: None): - pop = Population(self.filename) - columns = pop.columns - - assert columns['history'].tolist() == self.history_data.columns.tolist() - assert columns['oneline'].tolist() == self.oneline_data.columns.tolist() - - - - - # Test formation channel calculation, I need a specific test file for this, - # since it requires specific columns to be present in the oneline and history dataframes - - # Test create_transient_population method requires a specific test file for this, - # since it requires specific columns to be present in the oneline and history dataframes - - -class TestTransientPopulation: - pass - # to implement - - - -class TestRates: - pass - # to implement - -# Run the tests - -if __name__ == '__main__': - pytest.main() diff --git a/posydon/tests/visualization/test_VHdiagram.py b/posydon/tests/visualization/test_VHdiagram.py deleted file mode 100644 index 4e0fa81276..0000000000 --- a/posydon/tests/visualization/test_VHdiagram.py +++ /dev/null @@ -1,93 +0,0 @@ -import os -import unittest -from unittest.mock import patch - -from PyQt5.QtCore import QTimer -from PyQt5.QtWidgets import QApplication - -from posydon.config import PATH_TO_POSYDON -from posydon.visualization.VH_diagram.Presenter import Presenter, PresenterMode - -PATH_TO_DATASET = os.path.join( - PATH_TO_POSYDON, - "posydon", - "tests", - "data", - "POSYDON-UNIT-TESTS", - "visualization", - "20000_binaries.csv.gz" -) - -# https://stackoverflow.com/questions/60692711/cant-create-python-qapplication-in-github-action - -# if not os.path.exists(PATH_TO_DATASET): -# raise ValueError("Dataset for unit testing (VH diagram) was not found!") -# -# -# class TestVHdiagram(unittest.TestCase): -# def test_termination_detailled_view(self): -# app = QApplication.instance() -# if not app: -# app = QApplication([]) -# -# presenter = Presenter(PATH_TO_DATASET) -# -# with patch('PyQt5.QtWidgets.QMainWindow.show') as show_patch: -# presenter.present(19628, PresenterMode.DETAILED) -# -# QTimer.singleShot(0, lambda : presenter.close() ) -# -# app.exec_() -# -# assert show_patch.called -# -# def test_termination_reduced_view(self): -# app = QApplication.instance() -# if not app: -# app = QApplication([]) -# -# presenter = Presenter(PATH_TO_DATASET) -# -# with patch('PyQt5.QtWidgets.QMainWindow.show') as show_patch: -# presenter.present(19628, PresenterMode.REDUCED) -# -# QTimer.singleShot(0, lambda : presenter.close() ) -# -# app.exec_() -# -# assert show_patch.called -# -# def test_termination_simplified_view(self): -# app = QApplication.instance() -# if not app: -# app = QApplication([]) -# -# presenter = Presenter(PATH_TO_DATASET) -# -# with patch('PyQt5.QtWidgets.QMainWindow.show') as show_patch: -# presenter.present(19628, PresenterMode.SIMPLIFIED) -# -# QTimer.singleShot(0, lambda : presenter.close() ) -# -# app.exec_() -# -# assert show_patch.called -# -# def test_termination_diagram_view(self): -# app = QApplication.instance() -# if not app: -# app = QApplication([]) -# -# presenter = Presenter(PATH_TO_DATASET) -# -# with patch('PyQt5.QtWidgets.QMainWindow.show') as show_patch: -# presenter.present(19628, PresenterMode.DIAGRAM) -# -# QTimer.singleShot(0, lambda : presenter.close() ) -# -# app.exec_() -# -# assert show_patch.called -# -# if __name__ == "__main__": -# unittest.main() diff --git a/posydon/tests/visualization/test_plot1D.py b/posydon/tests/visualization/test_plot1D.py deleted file mode 100644 index 5621832351..0000000000 --- a/posydon/tests/visualization/test_plot1D.py +++ /dev/null @@ -1,110 +0,0 @@ -import os -import unittest -from unittest.mock import patch - -from posydon.config import PATH_TO_POSYDON -from posydon.grids.psygrid import PSyGrid - -PATH_TO_GRID = os.path.join( - PATH_TO_POSYDON, - "posydon/tests/data/POSYDON-UNIT-TESTS/" - "visualization/grid_unit_test_plot.h5" -) - -if not os.path.exists(PATH_TO_GRID): - raise ValueError("Test grid for unit testing was not found!") - - -class TestPlot1D(unittest.TestCase): - def test_one_track_one_var_plotting(self): - grid = PSyGrid(PATH_TO_GRID) - with patch('matplotlib.pyplot.show') as show_patch: - grid.plot(42, - "star_age", - "center_he4", - history="history1", - **{'show_fig': True}) - assert show_patch.called - with patch('matplotlib.pyplot.show') as show_patch: - grid.plot(42, - "age", - "star_1_mass", - history="binary_history", - **{'show_fig': True}) - assert show_patch.called - - def test_one_track_many_vars_plotting(self): - grid = PSyGrid(PATH_TO_GRID) - with patch('matplotlib.pyplot.show') as show_patch: - grid.plot(42, - "star_age", ["center_he4", "log_LHe"], - history="history1", - **{'show_fig': True}) - assert show_patch.called - with patch('matplotlib.pyplot.show') as show_patch: - grid.plot( - 42, - "age", - ["star_1_mass", "binary_separation", "rl_relative_overflow_1"], - history="binary_history", - **{'show_fig': True}) - assert show_patch.called - - def test_many_tracks_one_var_plotting(self): - grid = PSyGrid(PATH_TO_GRID) - with patch('matplotlib.pyplot.show') as show_patch: - grid.plot([42, 43, 44], - "star_age", - "center_he4", - history="history1", - **{'show_fig': True}) - assert show_patch.called - - def test_many_tracks_many_vars_plotting(self): - grid = PSyGrid(PATH_TO_GRID) - with patch('matplotlib.pyplot.show') as show_patch: - grid.plot( - [42, 43, 44], - "age", - ["star_1_mass", "binary_separation", "rl_relative_overflow_1"], - history="binary_history", - **{'show_fig': True}) - assert show_patch.called - - def test_one_track_one_var_extra_var_color_plotting(self): - grid = PSyGrid(PATH_TO_GRID) - with patch('matplotlib.pyplot.show') as show_patch: - grid.plot(42, - "age", - "star_1_mass", - "period_days", - history="binary_history", - **{'show_fig': True}) - assert show_patch.called - - def test_many_tracks_one_var_extra_var_color_plotting(self): - grid = PSyGrid(PATH_TO_GRID) - with patch('matplotlib.pyplot.show') as show_patch: - grid.plot([42, 43], - "age", - "star_1_mass", - "period_days", - history="binary_history", - **{'show_fig': True}) - assert show_patch.called - - def test_one_track_HR_plotting(self): - grid = PSyGrid(PATH_TO_GRID) - with patch('matplotlib.pyplot.show') as show_patch: - grid.HR(42, history="history1", **{'show_fig': True}) - assert show_patch.called - - def test_many_tracks_HR_plotting(self): - grid = PSyGrid(PATH_TO_GRID) - with patch('matplotlib.pyplot.show') as show_patch: - grid.HR([42, 43, 44], history="history1", **{'show_fig': True}) - assert show_patch.called - - -if __name__ == "__main__": - unittest.main() diff --git a/posydon/tests/visualization/test_plot2D.py b/posydon/tests/visualization/test_plot2D.py deleted file mode 100644 index 637ebb72e8..0000000000 --- a/posydon/tests/visualization/test_plot2D.py +++ /dev/null @@ -1,157 +0,0 @@ -import os -import unittest -from unittest.mock import patch - -from posydon.config import PATH_TO_POSYDON -from posydon.grids.psygrid import PSyGrid - -PATH_TO_GRID = os.path.join( - PATH_TO_POSYDON, - "posydon/tests/data/POSYDON-UNIT-TESTS/" - "visualization/grid_unit_test_plot.h5" -) -if not os.path.exists(PATH_TO_GRID): - raise ValueError("Test grid for unit testing was not found!") - -# class TestPlot2D(unittest.TestCase): -# def test_termination_flag_1_plotting(self): -# grid = PSyGrid(PATH_TO_GRID) -# with patch('matplotlib.pyplot.show') as show_patch: -# grid.plot2D("star_1_mass", -# "period_days", -# "star_1_mass", -# termination_flag="termination_flag_1", -# grid_3D=True, -# slice_3D_var_str="star_2_mass", -# slice_3D_var_range=(2.5, 3.0), -# **{'show_fig': True}) -# assert show_patch.called -# with patch('matplotlib.pyplot.show') as show_patch: -# grid.plot2D("star_1_mass", -# "period_days", -# "c_core_mass", -# termination_flag="termination_flag_1", -# grid_3D=True, -# slice_3D_var_str="star_2_mass", -# slice_3D_var_range=(2.5, 3.0), -# **{'show_fig': True}) -# assert show_patch.called -# with patch('matplotlib.pyplot.show') as show_patch: -# grid.plot2D("star_1_mass", -# "period_days", -# "binary_separation", -# termination_flag="termination_flag_1", -# grid_3D=True, -# slice_3D_var_str="star_2_mass", -# slice_3D_var_range=(2.5, 3.0), -# **{'show_fig': True}) -# assert show_patch.called -# -# def test_termination_flag_2_plotting(self): -# grid = PSyGrid(PATH_TO_GRID) -# with patch('matplotlib.pyplot.show') as show_patch: -# grid.plot2D("star_1_mass", -# "period_days", -# None, -# termination_flag="termination_flag_2", -# grid_3D=True, -# slice_3D_var_str="star_2_mass", -# slice_3D_var_range=(2.5, 3.0), -# **{'show_fig': True}) -# assert show_patch.called -# -# def test_termination_flag_3_plotting(self): -# grid = PSyGrid(PATH_TO_GRID) -# with patch('matplotlib.pyplot.show') as show_patch: -# grid.plot2D("star_1_mass", -# "period_days", -# None, -# termination_flag="termination_flag_3", -# grid_3D=True, -# slice_3D_var_str="star_2_mass", -# slice_3D_var_range=(2.5, 3.0), -# **{'show_fig': True}) -# assert show_patch.called -# -# def test_termination_flag_4_plotting(self): -# grid = PSyGrid(PATH_TO_GRID) -# with patch('matplotlib.pyplot.show') as show_patch: -# grid.plot2D("star_1_mass", -# "period_days", -# None, -# termination_flag="termination_flag_4", -# grid_3D=True, -# slice_3D_var_str="star_2_mass", -# slice_3D_var_range=(2.5, 3.0), -# **{'show_fig': True}) -# assert show_patch.called -# -# def test_all_termination_flags_plotting(self): -# grid = PSyGrid(PATH_TO_GRID) -# with patch('matplotlib.pyplot.show') as show_patch: -# grid.plot2D("star_1_mass", -# "period_days", -# "binary_separation", -# termination_flag="all", -# grid_3D=True, -# slice_3D_var_str="star_2_mass", -# slice_3D_var_range=(2.5, 3.0), -# **{'show_fig': True}) -# assert show_patch.called -# -# def test_RLO_plotting(self): -# grid = PSyGrid(PATH_TO_GRID) -# # with patch('matplotlib.pyplot.show') as show_patch: -# # grid.plot2D("star_1_mass", -# # "period_days", -# # "c_core_mass", -# # termination_flag="termination_flag_1", -# # grid_3D=True, -# # slice_3D_var_str="star_2_mass", -# # slice_3D_var_range=(2.5, 3.0), -# # slice_at_RLO=True, -# # **{ -# # 'show_fig': True -# # }) -# # assert show_patch.called -# # with patch('matplotlib.pyplot.show') as show_patch: -# # grid.plot2D("star_1_mass", -# # "period_days", -# # "star_1_mass", -# # termination_flag="termination_flag_1", -# # grid_3D=True, -# # slice_3D_var_str="star_2_mass", -# # slice_3D_var_range=(2.5, 3.0), -# # slice_at_RLO=True, -# # **{ -# # 'show_fig': True -# # }) -# # assert show_patch.called -# -# def test_extra_grid_plotting(self): -# grid = PSyGrid(PATH_TO_GRID) -# with patch('matplotlib.pyplot.show') as show_patch: -# grid.plot2D("star_1_mass", -# "period_days", -# "star_1_mass", -# termination_flag="termination_flag_1", -# grid_3D=True, -# slice_3D_var_str="star_2_mass", -# slice_3D_var_range=(2.5, 3.0), -# extra_grid=grid, -# **{'show_fig': True}) -# with patch('matplotlib.pyplot.show') as show_patch: -# grid.plot2D("star_1_mass", -# "period_days", -# "star_1_mass", -# termination_flag="termination_flag_1", -# grid_3D=True, -# slice_3D_var_str="star_2_mass", -# slice_3D_var_range=(2.5, 3.0), -# extra_grid=grid, -# **{'show_fig': True}) -# assert show_patch.called -# -# -# if __name__ == "__main__": -# unittest.main() diff --git a/pyproject.toml b/pyproject.toml index ed69b72524..746f28c5c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,30 +2,21 @@ # details: https://packaging.python.org/en/latest/guides/writing-pyproject-toml [build-system] -requires = ["setuptools >= 76.0.0", "versioneer"] +requires = ["setuptools >= 76.0.0", "setuptools-scm >= 8.0"] build-backend = "setuptools.build_meta" [project] -dynamic = [ - "description", - "license", - "version", - "requires-python", - "classifiers", - "dependencies", - "optional-dependencies", -] name = "posydon" -#description = "POSYDON the Next Generation of Population Synthesis" +description = "POSYDON the Next Generation of Population Synthesis" authors = [ {name = "POSYDON Collaboration", email = "posydon.team@gmail.com"}, ] maintainers = [ {name = "POSYDON Collaboration", email = "posydon.team@gmail.com"}, ] -#license = 'GPLv3+' -#version = "2.0.0.dev" -#requires-python = ">=3.11, <3.12" +license = "GPL-3.0-or-later" +dynamic = ["version"] +requires-python = ">=3.11, <3.12" readme = "README.md" keywords = [ "POSYDON", @@ -34,60 +25,67 @@ keywords = [ "Population Synthesis", "MESA", ] -#classifiers = [ -# 'Development Status :: 4 - Beta', -# 'Intended Audience :: Science/Research', -# 'Intended Audience :: End Users/Desktop', -# 'Topic :: Scientific/Engineering', -# 'Topic :: Scientific/Engineering :: Astronomy', -# 'Topic :: Scientific/Engineering :: Physics', -# 'Programming Language :: Python', -# 'Programming Language :: Python :: 3.11', -# 'Operating System :: POSIX', -# 'Operating System :: Unix', -# 'Operating System :: MacOS', -# 'Natural Language :: English', -# 'License :: OSI Approved :: GNU General Public License v3 (GPLv3+)', -#] -#dependencies = [ -# 'numpy < 2.0.0, >= 1.24.2', -# 'scipy <= 1.14.1, >= 1.10.1', -# 'iminuit <= 2.30.1, >= 2.21.3', -# 'configparser <= 7.1.0, >= 5.3.0', -# 'astropy <= 6.1.6, >= 5.2.2', -# 'pandas <= 2.2.3, >= 2.0.0', -# 'scikit-learn == 1.2.2', -# 'matplotlib <= 3.9.2, >= 3.9.0', -# 'matplotlib-label-lines <= 0.7.0, >= 0.5.2', -# 'h5py <= 3.12.1, >= 3.8.0', -# 'psutil <= 6.1.0, >= 5.9.4', -# 'tqdm <= 4.67.0, >= 4.65.0', -# 'tables <= 3.10.1, >= 3.8.0', -# 'progressbar2 <= 4.5.0, >= 4.2.0', -# 'hurry.filesize <= 0.9, >= 0.9', -# 'python-dotenv <= 1.0.1, >= 1.0.0', -#] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Science/Research", + "Intended Audience :: End Users/Desktop", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Astronomy", + "Topic :: Scientific/Engineering :: Physics", + "Programming Language :: Python", + "Programming Language :: Python :: 3.11", + "Operating System :: POSIX", + "Operating System :: Unix", + "Operating System :: MacOS", + "Natural Language :: English", +] +dependencies = [ + "numpy >= 1.24.2, < 2.0.0", + "scipy >= 1.10.1, <= 1.14.1", + "iminuit >= 2.21.3, <= 2.30.1", + "configparser >= 5.3.0, <= 7.1.0", + "astropy >= 5.2.2, <= 6.1.6", + "pandas >= 2.0.0, <= 2.2.3", + "scikit-learn == 1.2.2", + "matplotlib >= 3.9.0, <= 3.9.2", + "matplotlib-label-lines >= 0.5.2, <= 0.7.0", + "h5py >= 3.8.0, <= 3.12.1", + "psutil >= 5.9.4, <= 6.1.0", + "tqdm >= 4.65.0, <= 4.67.0", + "tables >= 3.8.0, <= 3.10.1", + "progressbar2 >= 4.2.0, <= 4.5.0", + "hurry.filesize >= 0.9, <= 0.9", + "python-dotenv >= 1.0.0, <= 1.0.1", +] -#[project.optional-dependencies] -#doc = [ -# 'ipython', -# 'sphinx >= 8.2.2', -# 'numpydoc', -# 'sphinx_rtd_theme', -# 'sphinxcontrib_programoutput', -# 'PSphinxTheme', -# 'nbsphinx', -# 'pandoc' -#] -#vis = [ -# 'PyQt5 <= 5.15.11, >= 5.15.9' -#] -#ml = [ -# 'tensorflow >= 2.13.0' -#] -#hpc = [ -# 'mpi4py >= 3.0.3' -#] +[project.optional-dependencies] +doc = [ + "ipython", + "sphinx >= 8.2.2", + "numpydoc", + "sphinx_rtd_theme", + "sphinxcontrib_programoutput", + "PSphinxTheme", + "nbsphinx", + "pandoc", +] +vis = [ + "PyQt5 >= 5.15.9, <= 5.15.11", +] +ml = [ + "tensorflow >= 2.13.0", +] +hpc = [ + "mpi4py >= 3.0.3", +] +dev = [ + "pre-commit >= 3.7.0", + "isort >= 5.13.2", +] +test = [ + "pytest >= 7.3.1", + "pytest-cov >= 4.0.0", +] [project.urls] Homepage = "https://posydon.org" @@ -95,3 +93,42 @@ Documentation = "https://posydon.org/POSYDON" Repository = "https://github.com/POSYDON-code/POSYDON.git" Issues = "https://github.com/POSYDON-code/POSYDON/issues" Changelog = "https://github.com/POSYDON-code/POSYDON/releases" + +[tool.setuptools.packages.find] +include = ["posydon*"] + +[tool.setuptools] +include-package-data = true +script-files = [ + "bin/compress-mesa", + "bin/get-posydon-data", + "bin/posydon-popsyn", + "bin/posydon-run-grid", + "bin/posydon-run-pipeline", + "bin/posydon-setup-grid", + "bin/posydon-setup-pipeline", +] + +[tool.setuptools_scm] +# setuptools-scm will automatically determine version from git tags + +[tool.pytest.ini_options] +addopts = "--verbose -r s --cov --cov-branch --cov-report=term-missing --cov-fail-under=100" +testpaths = ["posydon/unit_tests"] + +[tool.coverage.run] +branch = true +source = [ + "posydon.config", + "posydon.utils", + "posydon.grids", + "posydon.popsyn.IMFs", + "posydon.popsyn.norm_pop", + "posydon.popsyn.distributions", + "posydon.popsyn.star_formation_history", + "posydon.CLI", +] + +[tool.coverage.report] +fail_under = 100 +show_missing = true diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 3a53872f29..0000000000 --- a/setup.cfg +++ /dev/null @@ -1,22 +0,0 @@ -[aliases] -test = pytest - -[bdist_wheel] -universal = 1 - -[tool:pytest] -addopts = --verbose -r s - -[versioneer] -VCS = git -style = pep440 -versionfile_source = posydon/_version.py -versionfile_build = posydon/_version.py -tag_prefix = v -parentdir_prefix = - -[coverage:run] -source = posydon -omit = - posydon/tests/* - posydon/_version.py diff --git a/setup.py b/setup.py index d8e1e7e709..adc1fb2678 100644 --- a/setup.py +++ b/setup.py @@ -1,174 +1,19 @@ -"""Setup the posydon package.""" +"""Minimal setup.py for POSYDON package. -from __future__ import print_function +All configuration is in pyproject.toml. +This file only handles optional sphinx documentation builds. +""" -import glob -import os.path -import sys - -sys.path.insert(0, os.path.dirname(__file__)) - -import versioneer +from setuptools import setup +# Optional: Add sphinx documentation build command cmdclass = {} - - -# VERSIONING - -__version__ = versioneer.get_version() -cmdclass.update(versioneer.get_cmdclass()) - - -# TOGGLE WRAPPING C/C++ OR FORTRAN - -WRAP_C_CPP_OR_FORTRAN = False - -if WRAP_C_CPP_OR_FORTRAN: - from distutils.command.sdist import sdist - - try: - from numpy.distutils.core import Extension, setup - except ImportError: - raise ImportError("Building fortran extensions requires numpy.") - - cmdclass["sdist"] = sdist -else: - from setuptools import find_packages, setup - - -# DOCUMENTATION - -# import sphinx commands try: from sphinx.setup_command import BuildDoc + cmdclass["build_sphinx"] = BuildDoc except ImportError: pass -else: - cmdclass["build_sphinx"] = BuildDoc - -# read description -with open("README.md", "rb") as f: - longdesc = "f.read().decode().strip()" - - -# DEPENDENCIES -setup_requires = [ - 'setuptools >= 76.0.0', -] -if 'test' in sys.argv: - setup_requires += [ - 'pytest-runner', - ] - - -# These pretty common requirement are commented out. Various syntax types -# are all used in the example below for specifying specific version of the -# packages that are compatbile with your software. -# TODO NOTE: before the v2.0.0 code release, we should froze the versions -# the correct way to do this is to make sure that they are available on -# conda and pip for all platforms we support (see prerequisites doc page). -install_requires = [ - 'numpy >= 1.24.2, < 2.0.0', - 'scipy >= 1.10.1, <= 1.14.1', - 'iminuit >= 2.21.3, <= 2.30.1', - 'configparser >= 5.3.0, <= 7.1.0', - 'astropy >= 5.2.2, <= 6.1.6', - 'pandas >= 2.0.0, <= 2.2.3', - 'scikit-learn == 1.2.2', - 'matplotlib >= 3.9.0, <= 3.9.2', - 'matplotlib-label-lines >= 0.5.2, <= 0.7.0', - 'h5py >= 3.8.0, <= 3.12.1', - 'psutil >= 5.9.4, <= 6.1.0', - 'tqdm >= 4.65.0, <= 4.67.0', - 'tables >= 3.8.0, <= 3.10.1', - 'progressbar2 >= 4.2.0, <= 4.5.0', # for downloading data - 'hurry.filesize >= 0.9, <= 0.9', - 'python-dotenv >= 1.0.0, <= 1.0.1', -] - -tests_require = [ - "pytest >= 7.3.1", - "pytest-cov >= 4.0.0", -] - -# For documentation -extras_require = { - # to build documentation - "doc": [ - "ipython", - "sphinx >= 8.2.2", - "numpydoc", - "sphinx_rtd_theme", - "sphinxcontrib_programoutput", - "PSphinxTheme", - "nbsphinx", - "pandoc", - ], - # for experimental visualization features, e.g. VDH diagrams - "vis": ["PyQt5 >= 5.15.9, <= 5.15.11"], - # for profile machine learning features, e.g. profile interpolation - "ml": ["tensorflow >= 2.13.0"], - # for running population synthesis on HPC facilities - "hpc": ["mpi4py >= 3.0.3"], - # development tooling - 'dev': [ - 'pre-commit >= 3.7.0', - 'isort >= 5.13.2', - ], -} - -# RUN SETUP - -packagenames = find_packages() - -# Executables go in a folder called bin -scripts = glob.glob(os.path.join("bin", "*")) - -PACKAGENAME = "posydon" -DISTNAME = "posydon" -AUTHOR = "POSYDON Collaboration" -AUTHOR_EMAIL = "posydon.team@gmail.com" -LICENSE = "GPLv3+" -DESCRIPTION = "POSYDON the Next Generation of Population Synthesis" -GITHUBURL = "https://github.com/POSYDON-code/POSYDON" - -# Additional included files via include_package_data are defined in MANIFEST.in -setup( - name=DISTNAME, - provides=[PACKAGENAME], - version=__version__, - description=DESCRIPTION, - long_description=longdesc, - long_description_content_type="text/markdown", - ext_modules=[wrapper] if WRAP_C_CPP_OR_FORTRAN else [], - author=AUTHOR, - author_email=AUTHOR_EMAIL, - license=LICENSE, - packages=packagenames, - include_package_data=True, - cmdclass=cmdclass, - url=GITHUBURL, - scripts=scripts, - setup_requires=setup_requires, - install_requires=install_requires, - tests_require=tests_require, - extras_require=extras_require, - python_requires=">=3.11, <3.12", - use_2to3=False, - classifiers=[ - "Development Status :: 4 - Beta", - "Intended Audience :: Science/Research", - "Intended Audience :: End Users/Desktop", - "Topic :: Scientific/Engineering", - "Topic :: Scientific/Engineering :: Astronomy", - "Topic :: Scientific/Engineering :: Physics", - "Programming Language :: Python", - "Programming Language :: Python :: 3.11", - "Operating System :: POSIX", - "Operating System :: Unix", - "Operating System :: MacOS", - "Natural Language :: English", - "License :: OSI Approved :: GNU General Public License v3 (GPLv3+)", - ], -) +# Minimal setup call - all metadata including version is in pyproject.toml +# Version is automatically determined by setuptools-scm from git tags +setup(cmdclass=cmdclass) diff --git a/versioneer.py b/versioneer.py deleted file mode 100644 index ef293d9b3e..0000000000 --- a/versioneer.py +++ /dev/null @@ -1,1825 +0,0 @@ - -# Version: 0.18 - -"""The Versioneer - like a rocketeer, but for versions. - -The Versioneer -============== - -* like a rocketeer, but for versions! -* https://github.com/warner/python-versioneer -* Brian Warner -* License: Public Domain -* Compatible With: python2.6, 2.7, 3.2, 3.3, 3.4, 3.5, 3.6, and pypy -* [![Latest Version] -(https://pypip.in/version/versioneer/badge.svg?style=flat) -](https://pypi.python.org/pypi/versioneer/) -* [![Build Status] -(https://travis-ci.org/warner/python-versioneer.png?branch=master) -](https://travis-ci.org/warner/python-versioneer) - -This is a tool for managing a recorded version number in distutils-based -python projects. The goal is to remove the tedious and error-prone "update -the embedded version string" step from your release process. Making a new -release should be as easy as recording a new tag in your version-control -system, and maybe making new tarballs. - - -## Quick Install - -* `pip install versioneer` to somewhere to your $PATH -* add a `[versioneer]` section to your setup.cfg (see below) -* run `versioneer install` in your source tree, commit the results - -## Version Identifiers - -Source trees come from a variety of places: - -* a version-control system checkout (mostly used by developers) -* a nightly tarball, produced by build automation -* a snapshot tarball, produced by a web-based VCS browser, like github's - "tarball from tag" feature -* a release tarball, produced by "setup.py sdist", distributed through PyPI - -Within each source tree, the version identifier (either a string or a number, -this tool is format-agnostic) can come from a variety of places: - -* ask the VCS tool itself, e.g. "git describe" (for checkouts), which knows - about recent "tags" and an absolute revision-id -* the name of the directory into which the tarball was unpacked -* an expanded VCS keyword ($Id$, etc) -* a `_version.py` created by some earlier build step - -For released software, the version identifier is closely related to a VCS -tag. Some projects use tag names that include more than just the version -string (e.g. "myproject-1.2" instead of just "1.2"), in which case the tool -needs to strip the tag prefix to extract the version identifier. For -unreleased software (between tags), the version identifier should provide -enough information to help developers recreate the same tree, while also -giving them an idea of roughly how old the tree is (after version 1.2, before -version 1.3). Many VCS systems can report a description that captures this, -for example `git describe --tags --dirty --always` reports things like -"0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the -0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has -uncommitted changes. - -The version identifier is used for multiple purposes: - -* to allow the module to self-identify its version: `myproject.__version__` -* to choose a name and prefix for a 'setup.py sdist' tarball - -## Theory of Operation - -Versioneer works by adding a special `_version.py` file into your source -tree, where your `__init__.py` can import it. This `_version.py` knows how to -dynamically ask the VCS tool for version information at import time. - -`_version.py` also contains `$Revision$` markers, and the installation -process marks `_version.py` to have this marker rewritten with a tag name -during the `git archive` command. As a result, generated tarballs will -contain enough information to get the proper version. - -To allow `setup.py` to compute a version too, a `versioneer.py` is added to -the top level of your source tree, next to `setup.py` and the `setup.cfg` -that configures it. This overrides several distutils/setuptools commands to -compute the version when invoked, and changes `setup.py build` and `setup.py -sdist` to replace `_version.py` with a small static file that contains just -the generated version data. - -## Installation - -See [INSTALL.md](./INSTALL.md) for detailed installation instructions. - -## Version-String Flavors - -Code which uses Versioneer can learn about its version string at runtime by -importing `_version` from your main `__init__.py` file and running the -`get_versions()` function. From the "outside" (e.g. in `setup.py`), you can -import the top-level `versioneer.py` and run `get_versions()`. - -Both functions return a dictionary with different flavors of version -information: - -* `['version']`: A condensed version string, rendered using the selected - style. This is the most commonly used value for the project's version - string. The default "pep440" style yields strings like `0.11`, - `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the "Styles" section - below for alternative styles. - -* `['full-revisionid']`: detailed revision identifier. For Git, this is the - full SHA1 commit id, e.g. "1076c978a8d3cfc70f408fe5974aa6c092c949ac". - -* `['date']`: Date and time of the latest `HEAD` commit. For Git, it is the - commit date in ISO 8601 format. This will be None if the date is not - available. - -* `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that - this is only accurate if run in a VCS checkout, otherwise it is likely to - be False or None - -* `['error']`: if the version string could not be computed, this will be set - to a string describing the problem, otherwise it will be None. It may be - useful to throw an exception in setup.py if this is set, to avoid e.g. - creating tarballs with a version string of "unknown". - -Some variants are more useful than others. Including `full-revisionid` in a -bug report should allow developers to reconstruct the exact code being tested -(or indicate the presence of local changes that should be shared with the -developers). `version` is suitable for display in an "about" box or a CLI -`--version` output: it can be easily compared against release notes and lists -of bugs fixed in various releases. - -The installer adds the following text to your `__init__.py` to place a basic -version in `YOURPROJECT.__version__`: - - from ._version import get_versions - __version__ = get_versions()['version'] - del get_versions - -## Styles - -The setup.cfg `style=` configuration controls how the VCS information is -rendered into a version string. - -The default style, "pep440", produces a PEP440-compliant string, equal to the -un-prefixed tag name for actual releases, and containing an additional "local -version" section with more detail for in-between builds. For Git, this is -TAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags ---dirty --always`. For example "0.11+2.g1076c97.dirty" indicates that the -tree is like the "1076c97" commit but has uncommitted changes (".dirty"), and -that this commit is two revisions ("+2") beyond the "0.11" tag. For released -software (exactly equal to a known tag), the identifier will only contain the -stripped tag, e.g. "0.11". - -Other styles are available. See [details.md](details.md) in the Versioneer -source tree for descriptions. - -## Debugging - -Versioneer tries to avoid fatal errors: if something goes wrong, it will tend -to return a version of "0+unknown". To investigate the problem, run `setup.py -version`, which will run the version-lookup code in a verbose mode, and will -display the full contents of `get_versions()` (including the `error` string, -which may help identify what went wrong). - -## Known Limitations - -Some situations are known to cause problems for Versioneer. This details the -most significant ones. More can be found on Github -[issues page](https://github.com/warner/python-versioneer/issues). - -### Subprojects - -Versioneer has limited support for source trees in which `setup.py` is not in -the root directory (e.g. `setup.py` and `.git/` are *not* siblings). The are -two common reasons why `setup.py` might not be in the root: - -* Source trees which contain multiple subprojects, such as - [Buildbot](https://github.com/buildbot/buildbot), which contains both - "master" and "slave" subprojects, each with their own `setup.py`, - `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI - distributions (and upload multiple independently-installable tarballs). -* Source trees whose main purpose is to contain a C library, but which also - provide bindings to Python (and perhaps other langauges) in subdirectories. - -Versioneer will look for `.git` in parent directories, and most operations -should get the right version string. However `pip` and `setuptools` have bugs -and implementation details which frequently cause `pip install .` from a -subproject directory to fail to find a correct version string (so it usually -defaults to `0+unknown`). - -`pip install --editable .` should work correctly. `setup.py install` might -work too. - -Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in -some later version. - -[Bug #38](https://github.com/warner/python-versioneer/issues/38) is tracking -this issue. The discussion in -[PR #61](https://github.com/warner/python-versioneer/pull/61) describes the -issue from the Versioneer side in more detail. -[pip PR#3176](https://github.com/pypa/pip/pull/3176) and -[pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve -pip to let Versioneer work correctly. - -Versioneer-0.16 and earlier only looked for a `.git` directory next to the -`setup.cfg`, so subprojects were completely unsupported with those releases. - -### Editable installs with setuptools <= 18.5 - -`setup.py develop` and `pip install --editable .` allow you to install a -project into a virtualenv once, then continue editing the source code (and -test) without re-installing after every change. - -"Entry-point scripts" (`setup(entry_points={"console_scripts": ..})`) are a -convenient way to specify executable scripts that should be installed along -with the python package. - -These both work as expected when using modern setuptools. When using -setuptools-18.5 or earlier, however, certain operations will cause -`pkg_resources.DistributionNotFound` errors when running the entrypoint -script, which must be resolved by re-installing the package. This happens -when the install happens with one version, then the egg_info data is -regenerated while a different version is checked out. Many setup.py commands -cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into -a different virtualenv), so this can be surprising. - -[Bug #83](https://github.com/warner/python-versioneer/issues/83) describes -this one, but upgrading to a newer version of setuptools should probably -resolve it. - -### Unicode version strings - -While Versioneer works (and is continually tested) with both Python 2 and -Python 3, it is not entirely consistent with bytes-vs-unicode distinctions. -Newer releases probably generate unicode version strings on py2. It's not -clear that this is wrong, but it may be surprising for applications when then -write these strings to a network connection or include them in bytes-oriented -APIs like cryptographic checksums. - -[Bug #71](https://github.com/warner/python-versioneer/issues/71) investigates -this question. - - -## Updating Versioneer - -To upgrade your project to a new release of Versioneer, do the following: - -* install the new Versioneer (`pip install -U versioneer` or equivalent) -* edit `setup.cfg`, if necessary, to include any new configuration settings - indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details. -* re-run `versioneer install` in your source tree, to replace - `SRC/_version.py` -* commit any changed files - -## Future Directions - -This tool is designed to make it easily extended to other version-control -systems: all VCS-specific components are in separate directories like -src/git/ . The top-level `versioneer.py` script is assembled from these -components by running make-versioneer.py . In the future, make-versioneer.py -will take a VCS name as an argument, and will construct a version of -`versioneer.py` that is specific to the given VCS. It might also take the -configuration arguments that are currently provided manually during -installation by editing setup.py . Alternatively, it might go the other -direction and include code from all supported VCS systems, reducing the -number of intermediate scripts. - - -## License - -To make Versioneer easier to embed, all its code is dedicated to the public -domain. The `_version.py` that it creates is also in the public domain. -Specifically, both are released under the Creative Commons "Public Domain -Dedication" license (CC0-1.0), as described in -https://creativecommons.org/publicdomain/zero/1.0/ . - -""" - -from __future__ import print_function - -try: - import configparser -except ImportError: - import ConfigParser as configparser - -import errno -import json -import os -import re -import subprocess -import sys - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_root(): - """Get the project root directory. - - We require that all commands are run from the project root, i.e. the - directory that contains setup.py, setup.cfg, and versioneer.py . - """ - root = os.path.realpath(os.path.abspath(os.getcwd())) - setup_py = os.path.join(root, "setup.py") - versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - # allow 'python path/to/setup.py COMMAND' - root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) - setup_py = os.path.join(root, "setup.py") - versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - err = ("Versioneer was unable to run the project root directory. " - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND').") - raise VersioneerBadRootError(err) - try: - # Certain runtime workflows (setup.py install/develop in a setuptools - # tree) execute all dependencies in a single python process, so - # "versioneer" may be imported multiple times, and python's shared - # module-import table will cache the first one. So we can't use - # os.path.dirname(__file__), as that will find whichever - # versioneer.py was first imported, even in later projects. - me = os.path.realpath(os.path.abspath(__file__)) - me_dir = os.path.normcase(os.path.splitext(me)[0]) - vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) - if me_dir != vsr_dir: - print("Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(me), versioneer_py)) - except NameError: - pass - return root - - -def get_config_from_root(root): - """Read the project setup.cfg file to determine Versioneer config.""" - # This might raise EnvironmentError (if setup.cfg is missing), or - # configparser.NoSectionError (if it lacks a [versioneer] section), or - # configparser.NoOptionError (if it lacks "VCS="). See the docstring at - # the top of versioneer.py for instructions on writing your setup.cfg . - setup_cfg = os.path.join(root, "setup.cfg") - parser = configparser.SafeConfigParser() - with open(setup_cfg, "r") as f: - parser.readfp(f) - VCS = parser.get("versioneer", "VCS") # mandatory - - def get(parser, name): - if parser.has_option("versioneer", name): - return parser.get("versioneer", name) - return None - cfg = VersioneerConfig() - cfg.VCS = VCS - cfg.style = get(parser, "style") or "" - cfg.versionfile_source = get(parser, "versionfile_source") - cfg.versionfile_build = get(parser, "versionfile_build") - cfg.tag_prefix = get(parser, "tag_prefix") - if cfg.tag_prefix in ("''", '""'): - cfg.tag_prefix = "" - cfg.parentdir_prefix = get(parser, "parentdir_prefix") - cfg.verbose = get(parser, "verbose") - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -# these dictionaries contain VCS-specific tools -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Get decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - p = None - for c in commands: - try: - dispcmd = str([c] + args) - # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) - break - except EnvironmentError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %s" % (commands,)) - return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % stdout) - return None, p.returncode - return stdout, p.returncode - - -LONG_VERSION_PY['git'] = ''' -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) - -"""Git implementation of _version.py.""" - -import errno -import os -import re -import subprocess -import sys - - -def get_keywords(): - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). - git_refnames = "%(DOLLAR)sFormat:%%d%(DOLLAR)s" - git_full = "%(DOLLAR)sFormat:%%H%(DOLLAR)s" - git_date = "%(DOLLAR)sFormat:%%ci%(DOLLAR)s" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_config(): - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "%(STYLE)s" - cfg.tag_prefix = "%(TAG_PREFIX)s" - cfg.parentdir_prefix = "%(PARENTDIR_PREFIX)s" - cfg.versionfile_source = "%(VERSIONFILE_SOURCE)s" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - p = None - for c in commands: - try: - dispcmd = str([c] + args) - # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) - break - except EnvironmentError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %%s" %% dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %%s" %% (commands,)) - return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: - if verbose: - print("unable to run %%s (error)" %% dispcmd) - print("stdout was %%s" %% stdout) - return None, p.returncode - return stdout, p.returncode - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for i in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %%s but none started with prefix %%s" %% - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") - date = keywords.get("date") - if date is not None: - # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %%d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) - if verbose: - print("discarding '%%s', no digits" %% ",".join(refs - tags)) - if verbose: - print("likely tags: %%s" %% ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - if verbose: - print("picking %%s" %% r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %%s not under git control" %% root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%%s*" %% tag_prefix], - cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%%s'" - %% describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%%s' doesn't start with prefix '%%s'" - print(fmt %% (full_tag, tag_prefix)) - pieces["error"] = ("tag '%%s' doesn't start with prefix '%%s'" - %% (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%%ci", "HEAD"], - cwd=root)[0].strip() - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%%d.g%%s" %% (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post.dev%%d" %% pieces["distance"] - else: - # exception #1 - rendered = "0.post.dev%%d" %% pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%%s" %% pieces["short"] - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%%s" %% pieces["short"] - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Eexceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%%s'" %% style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -def get_versions(): - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. - - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. - for i in cfg.versionfile_source.split('/'): - root = os.path.dirname(root) - except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} -''' - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") - date = keywords.get("date") - if date is not None: - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. "2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - if verbose: - print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%s*" % tag_prefix], - cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root)[0].strip() - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def do_vcs_install(manifest_in, versionfile_source, ipy): - """Git-specific installation logic for Versioneer. - - For Git, this means creating/changing .gitattributes to mark _version.py - for export-subst keyword substitution. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - files = [manifest_in, versionfile_source] - if ipy: - files.append(ipy) - try: - me = __file__ - if me.endswith(".pyc") or me.endswith(".pyo"): - me = os.path.splitext(me)[0] + ".py" - versioneer_file = os.path.relpath(me) - except NameError: - versioneer_file = "versioneer.py" - files.append(versioneer_file) - present = False - try: - f = open(".gitattributes", "r") - for line in f.readlines(): - if line.strip().startswith(versionfile_source): - if "export-subst" in line.strip().split()[1:]: - present = True - f.close() - except EnvironmentError: - pass - if not present: - f = open(".gitattributes", "a+") - f.write("%s export-subst\n" % versionfile_source) - f.close() - files.append(".gitattributes") - run_command(GITS, ["add", "--"] + files) - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for i in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -SHORT_VERSION_PY = """ -# This file was generated by 'versioneer.py' (0.18) from -# revision-control system data, or from the parent directory name of an -# unpacked source archive. Distribution tarballs contain a pre-generated copy -# of this file. - -import json - -version_json = ''' -%s -''' # END VERSION_JSON - - -def get_versions(): - return json.loads(version_json) -""" - - -def versions_from_file(filename): - """Try to determine the version from _version.py if present.""" - try: - with open(filename) as f: - contents = f.read() - except EnvironmentError: - raise NotThisMethod("unable to read _version.py") - mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) - if not mo: - mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) - if not mo: - raise NotThisMethod("no version_json in _version.py") - return json.loads(mo.group(1)) - - -def write_to_version_file(filename, versions): - """Write the given version number to the given _version.py file.""" - os.unlink(filename) - contents = json.dumps(versions, sort_keys=True, - indent=1, separators=(",", ": ")) - with open(filename, "w") as f: - f.write(SHORT_VERSION_PY % contents) - - print("set %s to '%s'" % (filename, versions["version"])) - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] - else: - # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Eexceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -class VersioneerBadRootError(Exception): - """The project root directory is unknown or missing key files.""" - - -def get_versions(verbose=False): - """Get the project version from whatever source is available. - - Returns dict with two keys: 'version' and 'full'. - """ - if "versioneer" in sys.modules: - # see the discussion in cmdclass.py:get_cmdclass() - del sys.modules["versioneer"] - - root = get_root() - cfg = get_config_from_root(root) - - assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" - handlers = HANDLERS.get(cfg.VCS) - assert handlers, "unrecognized VCS '%s'" % cfg.VCS - verbose = verbose or cfg.verbose - assert cfg.versionfile_source is not None, \ - "please set versioneer.versionfile_source" - assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" - - versionfile_abs = os.path.join(root, cfg.versionfile_source) - - # extract version from first of: _version.py, VCS command (e.g. 'git - # describe'), parentdir. This is meant to work for developers using a - # source checkout, for users of a tarball created by 'setup.py sdist', - # and for users of a tarball/zipball created by 'git archive' or github's - # download-from-tag feature or the equivalent in other VCSes. - - get_keywords_f = handlers.get("get_keywords") - from_keywords_f = handlers.get("keywords") - if get_keywords_f and from_keywords_f: - try: - keywords = get_keywords_f(versionfile_abs) - ver = from_keywords_f(keywords, cfg.tag_prefix, verbose) - if verbose: - print("got version from expanded keyword %s" % ver) - return ver - except NotThisMethod: - pass - - try: - ver = versions_from_file(versionfile_abs) - if verbose: - print("got version from file %s %s" % (versionfile_abs, ver)) - return ver - except NotThisMethod: - pass - - from_vcs_f = handlers.get("pieces_from_vcs") - if from_vcs_f: - try: - pieces = from_vcs_f(cfg.tag_prefix, root, verbose) - ver = render(pieces, cfg.style) - if verbose: - print("got version from VCS %s" % ver) - return ver - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - if verbose: - print("got version from parentdir %s" % ver) - return ver - except NotThisMethod: - pass - - if verbose: - print("unable to compute version") - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, "error": "unable to compute version", - "date": None} - - -def get_version(): - """Get the short version string for this project.""" - return get_versions()["version"] - - -def get_cmdclass(): - """Get the custom setuptools/distutils subclasses used by Versioneer.""" - if "versioneer" in sys.modules: - del sys.modules["versioneer"] - # this fixes the "python setup.py develop" case (also 'install' and - # 'easy_install .'), in which subdependencies of the main project are - # built (using setup.py bdist_egg) in the same python process. Assume - # a main project A and a dependency B, which use different versions - # of Versioneer. A's setup.py imports A's Versioneer, leaving it in - # sys.modules by the time B's setup.py is executed, causing B to run - # with the wrong versioneer. Setuptools wraps the sub-dep builds in a - # sandbox that restores sys.modules to it's pre-build state, so the - # parent is protected against the child's "import versioneer". By - # removing ourselves from sys.modules here, before the child build - # happens, we protect the child from the parent's versioneer too. - # Also see https://github.com/warner/python-versioneer/issues/52 - - cmds = {} - - # we add "version" to both distutils and setuptools - from distutils.core import Command - - class cmd_version(Command): - description = "report generated version string" - user_options = [] - boolean_options = [] - - def initialize_options(self): - pass - - def finalize_options(self): - pass - - def run(self): - vers = get_versions(verbose=True) - print("Version: %s" % vers["version"]) - print(" full-revisionid: %s" % vers.get("full-revisionid")) - print(" dirty: %s" % vers.get("dirty")) - print(" date: %s" % vers.get("date")) - if vers["error"]: - print(" error: %s" % vers["error"]) - cmds["version"] = cmd_version - - # we override "build_py" in both distutils and setuptools - # - # most invocation pathways end up running build_py: - # distutils/build -> build_py - # distutils/install -> distutils/build ->.. - # setuptools/bdist_wheel -> distutils/install ->.. - # setuptools/bdist_egg -> distutils/install_lib -> build_py - # setuptools/install -> bdist_egg ->.. - # setuptools/develop -> ? - # pip install: - # copies source tree to a tempdir before running egg_info/etc - # if .git isn't copied too, 'git describe' will fail - # then does setup.py bdist_wheel, or sometimes setup.py install - # setup.py egg_info -> ? - - # we override different "build_py" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.build_py import build_py as _build_py - else: - from distutils.command.build_py import build_py as _build_py - - class cmd_build_py(_build_py): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - _build_py.run(self) - # now locate _version.py in the new build/ directory and replace - # it with an updated value - if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, - cfg.versionfile_build) - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - cmds["build_py"] = cmd_build_py - - if "cx_Freeze" in sys.modules: # cx_freeze enabled? - from cx_Freeze.dist import build_exe as _build_exe - - # nczeczulin reports that py2exe won't like the pep440-style string - # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. - # setup(console=[{ - # "version": versioneer.get_version().split("+", 1)[0], # FILEVERSION - # "product_version": versioneer.get_version(), - # ... - - class cmd_build_exe(_build_exe): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - target_versionfile = cfg.versionfile_source - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - _build_exe.run(self) - os.unlink(target_versionfile) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - cmds["build_exe"] = cmd_build_exe - del cmds["build_py"] - - if 'py2exe' in sys.modules: # py2exe enabled? - try: - from py2exe.distutils_buildexe import py2exe as _py2exe # py3 - except ImportError: - from py2exe.build_exe import py2exe as _py2exe # py2 - - class cmd_py2exe(_py2exe): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - target_versionfile = cfg.versionfile_source - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - _py2exe.run(self) - os.unlink(target_versionfile) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - cmds["py2exe"] = cmd_py2exe - - # we override different "sdist" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.sdist import sdist as _sdist - else: - from distutils.command.sdist import sdist as _sdist - - class cmd_sdist(_sdist): - def run(self): - versions = get_versions() - self._versioneer_generated_versions = versions - # unless we update this, the command will keep using the old - # version - self.distribution.metadata.version = versions["version"] - return _sdist.run(self) - - def make_release_tree(self, base_dir, files): - root = get_root() - cfg = get_config_from_root(root) - _sdist.make_release_tree(self, base_dir, files) - # now locate _version.py in the new base_dir directory - # (remembering that it may be a hardlink) and replace it with an - # updated value - target_versionfile = os.path.join(base_dir, cfg.versionfile_source) - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, - self._versioneer_generated_versions) - cmds["sdist"] = cmd_sdist - - return cmds - - -CONFIG_ERROR = """ -setup.cfg is missing the necessary Versioneer configuration. You need -a section like: - - [versioneer] - VCS = git - style = pep440 - versionfile_source = src/myproject/_version.py - versionfile_build = myproject/_version.py - tag_prefix = - parentdir_prefix = myproject- - -You will also need to edit your setup.py to use the results: - - import versioneer - setup(version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), ...) - -Please read the docstring in ./versioneer.py for configuration instructions, -edit setup.cfg, and re-run the installer or 'python versioneer.py setup'. -""" - -SAMPLE_CONFIG = """ -# See the docstring in versioneer.py for instructions. Note that you must -# re-run 'versioneer.py setup' after changing this section, and commit the -# resulting files. - -[versioneer] -#VCS = git -#style = pep440 -#versionfile_source = -#versionfile_build = -#tag_prefix = -#parentdir_prefix = - -""" - -INIT_PY_SNIPPET = """ -from ._version import get_versions -__version__ = get_versions()['version'] -del get_versions -""" - - -def do_setup(): - """Perform VCS-independent setup function for installing Versioneer.""" - root = get_root() - try: - cfg = get_config_from_root(root) - except (EnvironmentError, configparser.NoSectionError, - configparser.NoOptionError) as e: - if isinstance(e, (EnvironmentError, configparser.NoSectionError)): - print("Adding sample versioneer config to setup.cfg", - file=sys.stderr) - with open(os.path.join(root, "setup.cfg"), "a") as f: - f.write(SAMPLE_CONFIG) - print(CONFIG_ERROR, file=sys.stderr) - return 1 - - print(" creating %s" % cfg.versionfile_source) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - - ipy = os.path.join(os.path.dirname(cfg.versionfile_source), - "__init__.py") - if os.path.exists(ipy): - try: - with open(ipy, "r") as f: - old = f.read() - except EnvironmentError: - old = "" - if INIT_PY_SNIPPET not in old: - print(" appending to %s" % ipy) - with open(ipy, "a") as f: - f.write(INIT_PY_SNIPPET) - else: - print(" %s unmodified" % ipy) - else: - print(" %s doesn't exist, ok" % ipy) - ipy = None - - # Make sure both the top-level "versioneer.py" and versionfile_source - # (PKG/_version.py, used by runtime code) are in MANIFEST.in, so - # they'll be copied into source distributions. Pip won't be able to - # install the package without this. - manifest_in = os.path.join(root, "MANIFEST.in") - simple_includes = set() - try: - with open(manifest_in, "r") as f: - for line in f: - if line.startswith("include "): - for include in line.split()[1:]: - simple_includes.add(include) - except EnvironmentError: - pass - # That doesn't cover everything MANIFEST.in can do - # (http://docs.python.org/2/distutils/sourcedist.html#commands), so - # it might give some false negatives. Appending redundant 'include' - # lines is safe, though. - if "versioneer.py" not in simple_includes: - print(" appending 'versioneer.py' to MANIFEST.in") - with open(manifest_in, "a") as f: - f.write("include versioneer.py\n") - else: - print(" 'versioneer.py' already in MANIFEST.in") - if cfg.versionfile_source not in simple_includes: - print(" appending versionfile_source ('%s') to MANIFEST.in" % - cfg.versionfile_source) - with open(manifest_in, "a") as f: - f.write("include %s\n" % cfg.versionfile_source) - else: - print(" versionfile_source already in MANIFEST.in") - - # Make VCS-specific changes. For git, this means creating/changing - # .gitattributes to mark _version.py for export-subst keyword - # substitution. - do_vcs_install(manifest_in, cfg.versionfile_source, ipy) - return 0 - - -def scan_setup_py(): - """Validate the contents of setup.py against Versioneer's expectations.""" - found = set() - setters = False - errors = 0 - with open("setup.py", "r") as f: - for line in f.readlines(): - if "import versioneer" in line: - found.add("import") - if "versioneer.get_cmdclass()" in line: - found.add("cmdclass") - if "versioneer.get_version()" in line: - found.add("get_version") - if "versioneer.VCS" in line: - setters = True - if "versioneer.versionfile_source" in line: - setters = True - if len(found) != 3: - print("") - print("Your setup.py appears to be missing some important items") - print("(but I might be wrong). Please make sure it has something") - print("roughly like the following:") - print("") - print(" import versioneer") - print(" setup( version=versioneer.get_version(),") - print(" cmdclass=versioneer.get_cmdclass(), ...)") - print("") - errors += 1 - if setters: - print("You should remove lines like 'versioneer.VCS = ' and") - print("'versioneer.versionfile_source = ' . This configuration") - print("now lives in setup.cfg, and should be removed from setup.py") - print("") - errors += 1 - return errors - - -if __name__ == "__main__": - cmd = sys.argv[1] - if cmd == "setup": - errors = do_setup() - errors += scan_setup_py() - if errors: - sys.exit(1)