Move collected data to instance.training_data

This commit is contained in:
2021-01-21 08:21:40 -06:00
parent 23dd311d75
commit 06402516e6
11 changed files with 97 additions and 89 deletions

View File

@@ -115,11 +115,11 @@ class LearningSolver:
def solve(
self,
instance,
model=None,
output="",
tee=False,
):
instance: Union[Instance, str],
model: Any = None,
output: str = "",
tee: bool = False,
) -> MIPSolveStats:
"""
Solves the given instance. If trained machine-learning models are
available, they will be used to accelerate the solution process.
@@ -127,20 +127,9 @@ class LearningSolver:
The argument `instance` may be either an Instance object or a
filename pointing to a pickled Instance object.
This method modifies the instance object. Specifically, the following
properties are set:
- instance.lp_solution
- instance.lp_value
- instance.lower_bound
- instance.upper_bound
- instance.solution
- instance.solver_log
Additional solver components may set additional properties. Please
see their documentation for more details. If a filename is provided,
then the file is modified in-place. That is, the original file is
overwritten.
This method adds a new training sample to `instance.training_sample`.
If a filename is provided, then the file is modified in-place. That is,
the original file is overwritten.
If `solver.solve_lp_first` is False, the properties lp_solution and
lp_value will be set to dummy values.
@@ -190,7 +179,7 @@ class LearningSolver:
def _solve(
self,
instance: Instance,
instance: Union[Instance, str],
model: Any = None,
output: str = "",
tee: bool = False,
@@ -211,14 +200,18 @@ class LearningSolver:
fileformat = "pickle"
with open(filename, "rb") as file:
instance = pickle.load(cast(IO[bytes], file))
assert isinstance(instance, Instance)
# Generate model
if model is None:
with RedirectOutput([]):
model = instance.to_model()
# Initialize training data
# Initialize training sample
training_sample: TrainingSample = {}
if not hasattr(instance, "training_data"):
instance.training_data = []
instance.training_data += [training_sample]
# Initialize internal solver
self.tee = tee
@@ -275,11 +268,6 @@ class LearningSolver:
for component in self.components.values():
component.after_solve(self, instance, model, stats, training_sample)
# Append training data
if not hasattr(instance, "training_data"):
instance.training_data = []
instance.training_data += [training_sample]
# Write to file, if necessary
if filename is not None and output is not None:
output_filename = output
@@ -350,7 +338,7 @@ class LearningSolver:
self._restore_miplearn_logger()
return stats
def fit(self, training_instances):
def fit(self, training_instances: Union[List[str], List[Instance]]) -> None:
if len(training_instances) == 0:
return
for component in self.components.values():