From d30c3232e6c4d51886da954b537c69091db7b866 Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Tue, 27 Jul 2021 11:22:40 -0500 Subject: [PATCH] FileInstance.save: create file when it does not already exist --- miplearn/features/sample.py | 4 ++-- miplearn/instance/file.py | 2 +- tests/instance/test_file.py | 10 +++++----- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/miplearn/features/sample.py b/miplearn/features/sample.py index e75f42e..5c19172 100644 --- a/miplearn/features/sample.py +++ b/miplearn/features/sample.py @@ -167,8 +167,8 @@ class Hdf5Sample(Sample): are actually accessed, and therefore it is more scalable. """ - def __init__(self, filename: str) -> None: - self.file = h5py.File(filename, "r+") + def __init__(self, filename: str, mode: str = "r+") -> None: + self.file = h5py.File(filename, mode) @overrides def get_bytes(self, key: str) -> Optional[bytes]: diff --git a/miplearn/instance/file.py b/miplearn/instance/file.py index f0e5664..14f9fdf 100644 --- a/miplearn/instance/file.py +++ b/miplearn/instance/file.py @@ -119,7 +119,7 @@ class FileInstance(Instance): @classmethod def save(cls, instance: Instance, filename: str) -> None: - h5 = Hdf5Sample(filename) + h5 = Hdf5Sample(filename, mode="w") instance_pkl = pickle.dumps(instance) h5.put_bytes("pickled", instance_pkl) diff --git a/tests/instance/test_file.py b/tests/instance/test_file.py index 14c7c73..6ff9767 100644 --- a/tests/instance/test_file.py +++ b/tests/instance/test_file.py @@ -15,17 +15,17 @@ def test_usage() -> None: original = GurobiSolver().build_test_instance_knapsack() # Save instance to disk - file = tempfile.NamedTemporaryFile() - FileInstance.save(original, file.name) - sample = Hdf5Sample(file.name) + filename = tempfile.mktemp() + FileInstance.save(original, filename) + sample = Hdf5Sample(filename) assert len(sample.get_bytes("pickled")) > 0 # Solve instance from disk solver = LearningSolver(solver=GurobiSolver()) - solver.solve(FileInstance(file.name)) + solver.solve(FileInstance(filename)) # Assert HDF5 contains training data - sample = FileInstance(file.name).get_samples()[0] + sample = FileInstance(filename).get_samples()[0] assert sample.get_scalar("mip_lower_bound") == 1183.0 assert sample.get_scalar("mip_upper_bound") == 1183.0 assert len(sample.get_vector("lp_var_values")) == 5