FileInstance.save: create file when it does not already exist

master
Alinson S. Xavier 4 years ago
parent 4f14b99a75
commit d30c3232e6
No known key found for this signature in database
GPG Key ID: DCA0DAD4D2F58624

@ -167,8 +167,8 @@ class Hdf5Sample(Sample):
are actually accessed, and therefore it is more scalable.
"""
def __init__(self, filename: str) -> None:
self.file = h5py.File(filename, "r+")
def __init__(self, filename: str, mode: str = "r+") -> None:
self.file = h5py.File(filename, mode)
@overrides
def get_bytes(self, key: str) -> Optional[bytes]:

@ -119,7 +119,7 @@ class FileInstance(Instance):
@classmethod
def save(cls, instance: Instance, filename: str) -> None:
h5 = Hdf5Sample(filename)
h5 = Hdf5Sample(filename, mode="w")
instance_pkl = pickle.dumps(instance)
h5.put_bytes("pickled", instance_pkl)

@ -15,17 +15,17 @@ def test_usage() -> None:
original = GurobiSolver().build_test_instance_knapsack()
# Save instance to disk
file = tempfile.NamedTemporaryFile()
FileInstance.save(original, file.name)
sample = Hdf5Sample(file.name)
filename = tempfile.mktemp()
FileInstance.save(original, filename)
sample = Hdf5Sample(filename)
assert len(sample.get_bytes("pickled")) > 0
# Solve instance from disk
solver = LearningSolver(solver=GurobiSolver())
solver.solve(FileInstance(file.name))
solver.solve(FileInstance(filename))
# Assert HDF5 contains training data
sample = FileInstance(file.name).get_samples()[0]
sample = FileInstance(filename).get_samples()[0]
assert sample.get_scalar("mip_lower_bound") == 1183.0
assert sample.get_scalar("mip_upper_bound") == 1183.0
assert len(sample.get_vector("lp_var_values")) == 5

Loading…
Cancel
Save