diff --git a/miplearn/features/sample.py b/miplearn/features/sample.py index e75f42e..5c19172 100644 --- a/miplearn/features/sample.py +++ b/miplearn/features/sample.py @@ -167,8 +167,8 @@ class Hdf5Sample(Sample): are actually accessed, and therefore it is more scalable. """ - def __init__(self, filename: str) -> None: - self.file = h5py.File(filename, "r+") + def __init__(self, filename: str, mode: str = "r+") -> None: + self.file = h5py.File(filename, mode) @overrides def get_bytes(self, key: str) -> Optional[bytes]: diff --git a/miplearn/instance/file.py b/miplearn/instance/file.py index f0e5664..14f9fdf 100644 --- a/miplearn/instance/file.py +++ b/miplearn/instance/file.py @@ -119,7 +119,7 @@ class FileInstance(Instance): @classmethod def save(cls, instance: Instance, filename: str) -> None: - h5 = Hdf5Sample(filename) + h5 = Hdf5Sample(filename, mode="w") instance_pkl = pickle.dumps(instance) h5.put_bytes("pickled", instance_pkl) diff --git a/tests/instance/test_file.py b/tests/instance/test_file.py index 14c7c73..6ff9767 100644 --- a/tests/instance/test_file.py +++ b/tests/instance/test_file.py @@ -15,17 +15,17 @@ def test_usage() -> None: original = GurobiSolver().build_test_instance_knapsack() # Save instance to disk - file = tempfile.NamedTemporaryFile() - FileInstance.save(original, file.name) - sample = Hdf5Sample(file.name) + filename = tempfile.mktemp() + FileInstance.save(original, filename) + sample = Hdf5Sample(filename) assert len(sample.get_bytes("pickled")) > 0 # Solve instance from disk solver = LearningSolver(solver=GurobiSolver()) - solver.solve(FileInstance(file.name)) + solver.solve(FileInstance(filename)) # Assert HDF5 contains training data - sample = FileInstance(file.name).get_samples()[0] + sample = FileInstance(filename).get_samples()[0] assert sample.get_scalar("mip_lower_bound") == 1183.0 assert sample.get_scalar("mip_upper_bound") == 1183.0 assert len(sample.get_vector("lp_var_values")) == 5