Add Component.xy and PrimalSolutionComponent.xy

This commit is contained in:
2021-03-30 17:08:10 -05:00
parent 75d1eee424
commit 9266743940
3 changed files with 172 additions and 2 deletions

View File

@@ -3,7 +3,7 @@
# Released under the modified BSD license. See COPYING.md for more details.
from abc import ABC, abstractmethod
from typing import Any, List, Union, TYPE_CHECKING
from typing import Any, List, Union, TYPE_CHECKING, Tuple, Dict
from miplearn.instance import Instance
from miplearn.types import LearningSolveStats, TrainingSample
@@ -12,6 +12,7 @@ if TYPE_CHECKING:
from miplearn.solvers.learning import LearningSolver
# noinspection PyMethodMayBeStatic
class Component(ABC):
"""
A Component is an object which adds functionality to a LearningSolver.
@@ -135,6 +136,17 @@ class Component(ABC):
) -> None:
return
def xy(
self,
instance: Any,
training_sample: TrainingSample,
) -> Tuple[Dict, Dict]:
"""
Given a training sample, returns a pair of x and y dictionaries containing,
respectively, the matrices of ML features and the labels for the sample.
"""
return {}, {}
def iteration_cb(
self,
solver: "LearningSolver",

View File

@@ -3,7 +3,17 @@
# Released under the modified BSD license. See COPYING.md for more details.
import logging
from typing import Union, Dict, Callable, List, Hashable, Optional, Any, TYPE_CHECKING
from typing import (
Union,
Dict,
Callable,
List,
Hashable,
Optional,
Any,
TYPE_CHECKING,
Tuple,
)
import numpy as np
from tqdm.auto import tqdm
@@ -286,3 +296,34 @@ class PrimalSolutionComponent(Component):
f"Please set its category to None."
)
return [opt_value < 0.5, opt_value > 0.5]
def xy(
self,
instance: Any,
sample: TrainingSample,
) -> Tuple[Dict, Dict]:
x: Dict = {}
y: Dict = {}
if "Solution" not in sample:
return x, y
assert sample["Solution"] is not None
for (var, var_dict) in sample["Solution"].items():
for (idx, opt_value) in var_dict.items():
assert opt_value is not None
assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, (
f"Variable {var} has non-binary value {opt_value} in the optimal "
f"solution. Predicting values of non-binary variables is not "
f"currently supported. Please set its category to None."
)
category = instance.get_variable_category(var, idx)
if category is None:
continue
if category not in x.keys():
x[category] = []
y[category] = []
features: Any = instance.get_variable_features(var, idx)
if "LP solution" in sample and sample["LP solution"] is not None:
features += [sample["LP solution"][var][idx]]
x[category] += [features]
y[category] += [[opt_value < 0.5, opt_value >= 0.5]]
return x, y