From a65ebfb17ce69ae21e98379541389420dd27b913 Mon Sep 17 00:00:00 2001 From: Alinson S Xavier Date: Tue, 10 Aug 2021 11:02:02 -0500 Subject: [PATCH] Re-enable half-precision; minor changes to FeaturesExtractor benchmark --- miplearn/features/sample.py | 2 ++ tests/features/test_extractor.py | 14 +++++++------- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/miplearn/features/sample.py b/miplearn/features/sample.py index df50552..39f3261 100644 --- a/miplearn/features/sample.py +++ b/miplearn/features/sample.py @@ -174,6 +174,8 @@ class Hdf5Sample(Sample): if value is None: return self._assert_is_array(value) + if len(value.shape) > 1 and value.dtype.kind == "f": + value = value.astype("float16") if key in self.file: del self.file[key] return self.file.create_dataset(key, data=value, compression="gzip") diff --git a/tests/features/test_extractor.py b/tests/features/test_extractor.py index 056a1bb..19b8767 100644 --- a/tests/features/test_extractor.py +++ b/tests/features/test_extractor.py @@ -11,7 +11,7 @@ import numpy as np import gurobipy as gp from miplearn.features.extractor import FeaturesExtractor -from miplearn.features.sample import MemorySample, Hdf5Sample +from miplearn.features.sample import Hdf5Sample from miplearn.instance.base import Instance from miplearn.solvers.gurobi import GurobiSolver from miplearn.solvers.internal import Variables, Constraints @@ -382,17 +382,17 @@ class MpsInstance(Instance): return gp.read(self.filename) -if __name__ == "__main__": +def main() -> None: solver = GurobiSolver() instance = MpsInstance(sys.argv[1]) solver.set_instance(instance) - lp_stats = solver.solve_lp(tee=True) extractor = FeaturesExtractor(with_lhs=False) sample = Hdf5Sample("tmp/prof.h5", mode="w") + extractor.extract_after_load_features(instance, solver, sample) + lp_stats = solver.solve_lp(tee=True) + extractor.extract_after_lp_features(solver, sample, lp_stats) - def run() -> None: - extractor.extract_after_load_features(instance, solver, sample) - extractor.extract_after_lp_features(solver, sample, lp_stats) - cProfile.run("run()", filename="tmp/prof") +if __name__ == "__main__": + cProfile.run("main()", filename="tmp/prof") os.system("flameprof tmp/prof > tmp/prof.svg")