Re-enable half-precision; minor changes to FeaturesExtractor benchmark

master
Alinson S. Xavier 4 years ago
parent 9cfb31bacb
commit a65ebfb17c

@ -174,6 +174,8 @@ class Hdf5Sample(Sample):
if value is None: if value is None:
return return
self._assert_is_array(value) self._assert_is_array(value)
if len(value.shape) > 1 and value.dtype.kind == "f":
value = value.astype("float16")
if key in self.file: if key in self.file:
del self.file[key] del self.file[key]
return self.file.create_dataset(key, data=value, compression="gzip") return self.file.create_dataset(key, data=value, compression="gzip")

@ -11,7 +11,7 @@ import numpy as np
import gurobipy as gp import gurobipy as gp
from miplearn.features.extractor import FeaturesExtractor from miplearn.features.extractor import FeaturesExtractor
from miplearn.features.sample import MemorySample, Hdf5Sample from miplearn.features.sample import Hdf5Sample
from miplearn.instance.base import Instance from miplearn.instance.base import Instance
from miplearn.solvers.gurobi import GurobiSolver from miplearn.solvers.gurobi import GurobiSolver
from miplearn.solvers.internal import Variables, Constraints from miplearn.solvers.internal import Variables, Constraints
@ -382,17 +382,17 @@ class MpsInstance(Instance):
return gp.read(self.filename) return gp.read(self.filename)
if __name__ == "__main__": def main() -> None:
solver = GurobiSolver() solver = GurobiSolver()
instance = MpsInstance(sys.argv[1]) instance = MpsInstance(sys.argv[1])
solver.set_instance(instance) solver.set_instance(instance)
lp_stats = solver.solve_lp(tee=True)
extractor = FeaturesExtractor(with_lhs=False) extractor = FeaturesExtractor(with_lhs=False)
sample = Hdf5Sample("tmp/prof.h5", mode="w") sample = Hdf5Sample("tmp/prof.h5", mode="w")
extractor.extract_after_load_features(instance, solver, sample)
lp_stats = solver.solve_lp(tee=True)
extractor.extract_after_lp_features(solver, sample, lp_stats)
def run() -> None:
extractor.extract_after_load_features(instance, solver, sample)
extractor.extract_after_lp_features(solver, sample, lp_stats)
cProfile.run("run()", filename="tmp/prof") if __name__ == "__main__":
cProfile.run("main()", filename="tmp/prof")
os.system("flameprof tmp/prof > tmp/prof.svg") os.system("flameprof tmp/prof > tmp/prof.svg")

Loading…
Cancel
Save