mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
FileInstance.save: create file when it does not already exist
This commit is contained in:
@@ -167,8 +167,8 @@ class Hdf5Sample(Sample):
|
|||||||
are actually accessed, and therefore it is more scalable.
|
are actually accessed, and therefore it is more scalable.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, filename: str) -> None:
|
def __init__(self, filename: str, mode: str = "r+") -> None:
|
||||||
self.file = h5py.File(filename, "r+")
|
self.file = h5py.File(filename, mode)
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
def get_bytes(self, key: str) -> Optional[bytes]:
|
def get_bytes(self, key: str) -> Optional[bytes]:
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ class FileInstance(Instance):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def save(cls, instance: Instance, filename: str) -> None:
|
def save(cls, instance: Instance, filename: str) -> None:
|
||||||
h5 = Hdf5Sample(filename)
|
h5 = Hdf5Sample(filename, mode="w")
|
||||||
instance_pkl = pickle.dumps(instance)
|
instance_pkl = pickle.dumps(instance)
|
||||||
h5.put_bytes("pickled", instance_pkl)
|
h5.put_bytes("pickled", instance_pkl)
|
||||||
|
|
||||||
|
|||||||
@@ -15,17 +15,17 @@ def test_usage() -> None:
|
|||||||
original = GurobiSolver().build_test_instance_knapsack()
|
original = GurobiSolver().build_test_instance_knapsack()
|
||||||
|
|
||||||
# Save instance to disk
|
# Save instance to disk
|
||||||
file = tempfile.NamedTemporaryFile()
|
filename = tempfile.mktemp()
|
||||||
FileInstance.save(original, file.name)
|
FileInstance.save(original, filename)
|
||||||
sample = Hdf5Sample(file.name)
|
sample = Hdf5Sample(filename)
|
||||||
assert len(sample.get_bytes("pickled")) > 0
|
assert len(sample.get_bytes("pickled")) > 0
|
||||||
|
|
||||||
# Solve instance from disk
|
# Solve instance from disk
|
||||||
solver = LearningSolver(solver=GurobiSolver())
|
solver = LearningSolver(solver=GurobiSolver())
|
||||||
solver.solve(FileInstance(file.name))
|
solver.solve(FileInstance(filename))
|
||||||
|
|
||||||
# Assert HDF5 contains training data
|
# Assert HDF5 contains training data
|
||||||
sample = FileInstance(file.name).get_samples()[0]
|
sample = FileInstance(filename).get_samples()[0]
|
||||||
assert sample.get_scalar("mip_lower_bound") == 1183.0
|
assert sample.get_scalar("mip_lower_bound") == 1183.0
|
||||||
assert sample.get_scalar("mip_upper_bound") == 1183.0
|
assert sample.get_scalar("mip_upper_bound") == 1183.0
|
||||||
assert len(sample.get_vector("lp_var_values")) == 5
|
assert len(sample.get_vector("lp_var_values")) == 5
|
||||||
|
|||||||
Reference in New Issue
Block a user