Re-enable half-precision; minor changes to FeaturesExtractor benchmark

This commit is contained in:
2021-08-10 11:02:02 -05:00
parent 9cfb31bacb
commit a65ebfb17c
2 changed files with 9 additions and 7 deletions

View File

@@ -174,6 +174,8 @@ class Hdf5Sample(Sample):
if value is None: if value is None:
return return
self._assert_is_array(value) self._assert_is_array(value)
if len(value.shape) > 1 and value.dtype.kind == "f":
value = value.astype("float16")
if key in self.file: if key in self.file:
del self.file[key] del self.file[key]
return self.file.create_dataset(key, data=value, compression="gzip") return self.file.create_dataset(key, data=value, compression="gzip")

View File

@@ -11,7 +11,7 @@ import numpy as np
import gurobipy as gp import gurobipy as gp
from miplearn.features.extractor import FeaturesExtractor from miplearn.features.extractor import FeaturesExtractor
from miplearn.features.sample import MemorySample, Hdf5Sample from miplearn.features.sample import Hdf5Sample
from miplearn.instance.base import Instance from miplearn.instance.base import Instance
from miplearn.solvers.gurobi import GurobiSolver from miplearn.solvers.gurobi import GurobiSolver
from miplearn.solvers.internal import Variables, Constraints from miplearn.solvers.internal import Variables, Constraints
@@ -382,17 +382,17 @@ class MpsInstance(Instance):
return gp.read(self.filename) return gp.read(self.filename)
if __name__ == "__main__": def main() -> None:
solver = GurobiSolver() solver = GurobiSolver()
instance = MpsInstance(sys.argv[1]) instance = MpsInstance(sys.argv[1])
solver.set_instance(instance) solver.set_instance(instance)
lp_stats = solver.solve_lp(tee=True)
extractor = FeaturesExtractor(with_lhs=False) extractor = FeaturesExtractor(with_lhs=False)
sample = Hdf5Sample("tmp/prof.h5", mode="w") sample = Hdf5Sample("tmp/prof.h5", mode="w")
extractor.extract_after_load_features(instance, solver, sample)
lp_stats = solver.solve_lp(tee=True)
extractor.extract_after_lp_features(solver, sample, lp_stats)
def run() -> None:
extractor.extract_after_load_features(instance, solver, sample)
extractor.extract_after_lp_features(solver, sample, lp_stats)
cProfile.run("run()", filename="tmp/prof") if __name__ == "__main__":
cProfile.run("main()", filename="tmp/prof")
os.system("flameprof tmp/prof > tmp/prof.svg") os.system("flameprof tmp/prof > tmp/prof.svg")