mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Move collected data to instance.training_data
This commit is contained in:
@@ -115,11 +115,11 @@ class LearningSolver:
|
||||
|
||||
def solve(
|
||||
self,
|
||||
instance,
|
||||
model=None,
|
||||
output="",
|
||||
tee=False,
|
||||
):
|
||||
instance: Union[Instance, str],
|
||||
model: Any = None,
|
||||
output: str = "",
|
||||
tee: bool = False,
|
||||
) -> MIPSolveStats:
|
||||
"""
|
||||
Solves the given instance. If trained machine-learning models are
|
||||
available, they will be used to accelerate the solution process.
|
||||
@@ -127,20 +127,9 @@ class LearningSolver:
|
||||
The argument `instance` may be either an Instance object or a
|
||||
filename pointing to a pickled Instance object.
|
||||
|
||||
This method modifies the instance object. Specifically, the following
|
||||
properties are set:
|
||||
|
||||
- instance.lp_solution
|
||||
- instance.lp_value
|
||||
- instance.lower_bound
|
||||
- instance.upper_bound
|
||||
- instance.solution
|
||||
- instance.solver_log
|
||||
|
||||
Additional solver components may set additional properties. Please
|
||||
see their documentation for more details. If a filename is provided,
|
||||
then the file is modified in-place. That is, the original file is
|
||||
overwritten.
|
||||
This method adds a new training sample to `instance.training_sample`.
|
||||
If a filename is provided, then the file is modified in-place. That is,
|
||||
the original file is overwritten.
|
||||
|
||||
If `solver.solve_lp_first` is False, the properties lp_solution and
|
||||
lp_value will be set to dummy values.
|
||||
@@ -190,7 +179,7 @@ class LearningSolver:
|
||||
|
||||
def _solve(
|
||||
self,
|
||||
instance: Instance,
|
||||
instance: Union[Instance, str],
|
||||
model: Any = None,
|
||||
output: str = "",
|
||||
tee: bool = False,
|
||||
@@ -211,14 +200,18 @@ class LearningSolver:
|
||||
fileformat = "pickle"
|
||||
with open(filename, "rb") as file:
|
||||
instance = pickle.load(cast(IO[bytes], file))
|
||||
assert isinstance(instance, Instance)
|
||||
|
||||
# Generate model
|
||||
if model is None:
|
||||
with RedirectOutput([]):
|
||||
model = instance.to_model()
|
||||
|
||||
# Initialize training data
|
||||
# Initialize training sample
|
||||
training_sample: TrainingSample = {}
|
||||
if not hasattr(instance, "training_data"):
|
||||
instance.training_data = []
|
||||
instance.training_data += [training_sample]
|
||||
|
||||
# Initialize internal solver
|
||||
self.tee = tee
|
||||
@@ -275,11 +268,6 @@ class LearningSolver:
|
||||
for component in self.components.values():
|
||||
component.after_solve(self, instance, model, stats, training_sample)
|
||||
|
||||
# Append training data
|
||||
if not hasattr(instance, "training_data"):
|
||||
instance.training_data = []
|
||||
instance.training_data += [training_sample]
|
||||
|
||||
# Write to file, if necessary
|
||||
if filename is not None and output is not None:
|
||||
output_filename = output
|
||||
@@ -350,7 +338,7 @@ class LearningSolver:
|
||||
self._restore_miplearn_logger()
|
||||
return stats
|
||||
|
||||
def fit(self, training_instances):
|
||||
def fit(self, training_instances: Union[List[str], List[Instance]]) -> None:
|
||||
if len(training_instances) == 0:
|
||||
return
|
||||
for component in self.components.values():
|
||||
|
||||
@@ -25,20 +25,19 @@ def test_learning_solver():
|
||||
)
|
||||
|
||||
solver.solve(instance)
|
||||
assert instance.solution["x"][0] == 1.0
|
||||
assert instance.solution["x"][1] == 0.0
|
||||
assert instance.solution["x"][2] == 1.0
|
||||
assert instance.solution["x"][3] == 1.0
|
||||
assert instance.lower_bound == 1183.0
|
||||
assert instance.upper_bound == 1183.0
|
||||
assert round(instance.lp_solution["x"][0], 3) == 1.000
|
||||
assert round(instance.lp_solution["x"][1], 3) == 0.923
|
||||
assert round(instance.lp_solution["x"][2], 3) == 1.000
|
||||
assert round(instance.lp_solution["x"][3], 3) == 0.000
|
||||
assert round(instance.lp_value, 3) == 1287.923
|
||||
assert instance.found_violated_lazy_constraints == []
|
||||
assert instance.found_violated_user_cuts == []
|
||||
assert len(instance.solver_log) > 100
|
||||
data = instance.training_data[0]
|
||||
assert data["Solution"]["x"][0] == 1.0
|
||||
assert data["Solution"]["x"][1] == 0.0
|
||||
assert data["Solution"]["x"][2] == 1.0
|
||||
assert data["Solution"]["x"][3] == 1.0
|
||||
assert data["Lower bound"] == 1183.0
|
||||
assert data["Upper bound"] == 1183.0
|
||||
assert round(data["LP solution"]["x"][0], 3) == 1.000
|
||||
assert round(data["LP solution"]["x"][1], 3) == 0.923
|
||||
assert round(data["LP solution"]["x"][2], 3) == 1.000
|
||||
assert round(data["LP solution"]["x"][3], 3) == 0.000
|
||||
assert round(data["LP value"], 3) == 1287.923
|
||||
assert len(data["MIP log"]) > 100
|
||||
|
||||
solver.fit([instance])
|
||||
solver.solve(instance)
|
||||
@@ -55,7 +54,8 @@ def test_parallel_solve():
|
||||
results = solver.parallel_solve(instances, n_jobs=3)
|
||||
assert len(results) == 10
|
||||
for instance in instances:
|
||||
assert len(instance.solution["x"].keys()) == 4
|
||||
data = instance.training_data[0]
|
||||
assert len(data["Solution"]["x"].keys()) == 4
|
||||
|
||||
|
||||
def test_solve_fit_from_disk():
|
||||
@@ -73,14 +73,14 @@ def test_solve_fit_from_disk():
|
||||
solver.solve(filenames[0])
|
||||
with open(filenames[0], "rb") as file:
|
||||
instance = pickle.load(file)
|
||||
assert hasattr(instance, "solution")
|
||||
assert len(instance.training_data) > 0
|
||||
|
||||
# Test: parallel_solve
|
||||
solver.parallel_solve(filenames)
|
||||
for filename in filenames:
|
||||
with open(filename, "rb") as file:
|
||||
instance = pickle.load(file)
|
||||
assert hasattr(instance, "solution")
|
||||
assert len(instance.training_data) > 0
|
||||
|
||||
# Test: solve (with specified output)
|
||||
output = [f + ".out" for f in filenames]
|
||||
|
||||
Reference in New Issue
Block a user