From 2979bd157c5a2c6a8ef584f1e5ceeccd88d8b68e Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Sun, 11 Apr 2021 21:52:59 -0500 Subject: [PATCH] Rewrite PrimalSolutionComponent.sample_xy --- miplearn/components/primal.py | 46 +++++++++++++++++++- tests/components/test_primal.py | 75 ++++++++++++++++++++++++++++++++- 2 files changed, 118 insertions(+), 3 deletions(-) diff --git a/miplearn/components/primal.py b/miplearn/components/primal.py index 6179453..b0880d6 100644 --- a/miplearn/components/primal.py +++ b/miplearn/components/primal.py @@ -20,7 +20,7 @@ from miplearn.classifiers.adaptive import AdaptiveClassifier from miplearn.classifiers.threshold import MinPrecisionThreshold, Threshold from miplearn.components import classifier_evaluation_dict from miplearn.components.component import Component -from miplearn.features import TrainingSample, Features +from miplearn.features import TrainingSample, Features, Sample from miplearn.instance.base import Instance from miplearn.types import ( LearningSolveStats, @@ -179,6 +179,50 @@ class PrimalSolutionComponent(Component): y[category] += [[opt_value < 0.5, opt_value >= 0.5]] return x, y + @overrides + def sample_xy( + self, + sample: Sample, + ) -> Tuple[Dict[Category, List[List[float]]], Dict[Category, List[List[float]]]]: + x: Dict = {} + y: Dict = {} + assert sample.after_load is not None + assert sample.after_load.variables is not None + for (var_name, var) in sample.after_load.variables.items(): + # Initialize categories + category = var.category + if category is None: + continue + if category not in x.keys(): + x[category] = [] + y[category] = [] + + # Features + sf = sample.after_load + if sample.after_lp is not None: + sf = sample.after_lp + assert sf.instance is not None + features = list(sf.instance.to_list()) + assert sf.variables is not None + assert sf.variables[var_name] is not None + features.extend(sf.variables[var_name].to_list()) + x[category].append(features) + + # Labels + if sample.after_mip is not None: + assert sample.after_mip.variables is not None + assert sample.after_mip.variables[var_name] is not None + opt_value = sample.after_mip.variables[var_name].value + assert opt_value is not None + assert 0.0 - 1e-5 <= opt_value <= 1.0 + 1e-5, ( + f"Variable {var_name} has non-binary value {opt_value} in the " + "optimal solution. Predicting values of non-binary " + "variables is not currently supported. Please set its " + "category to None." + ) + y[category].append([opt_value < 0.5, opt_value >= 0.5]) + return x, y + @overrides def sample_evaluate_old( self, diff --git a/tests/components/test_primal.py b/tests/components/test_primal.py index f03cb33..9daa337 100644 --- a/tests/components/test_primal.py +++ b/tests/components/test_primal.py @@ -5,6 +5,7 @@ from typing import cast from unittest.mock import Mock import numpy as np +import pytest from numpy.testing import assert_array_equal from scipy.stats import randint @@ -12,13 +13,83 @@ from miplearn.classifiers import Classifier from miplearn.classifiers.threshold import Threshold from miplearn.components import classifier_evaluation_dict from miplearn.components.primal import PrimalSolutionComponent -from miplearn.features import TrainingSample, Variable, Features +from miplearn.features import ( + TrainingSample, + Variable, + Features, + Sample, + InstanceFeatures, +) from miplearn.instance.base import Instance from miplearn.problems.tsp import TravelingSalesmanGenerator from miplearn.solvers.learning import LearningSolver -def test_xy() -> None: +@pytest.fixture +def sample() -> Sample: + sample = Sample( + after_load=Features( + variables={ + "x[0]": Variable(category="default"), + "x[1]": Variable(category=None), + "x[2]": Variable(category="default"), + "x[3]": Variable(category="default"), + }, + ), + after_lp=Features( + instance=InstanceFeatures(), + variables={ + "x[0]": Variable(), + "x[1]": Variable(), + "x[2]": Variable(), + "x[3]": Variable(), + }, + ), + after_mip=Features( + variables={ + "x[0]": Variable(value=0.0), + "x[1]": Variable(value=0.0), + "x[2]": Variable(value=1.0), + "x[3]": Variable(value=0.0), + } + ), + ) + sample.after_lp.instance.to_list = Mock(return_value=[5.0]) # type: ignore + sample.after_lp.variables["x[0]"].to_list = Mock( # type: ignore + return_value=[0.0, 0.0] + ) + sample.after_lp.variables["x[2]"].to_list = Mock( # type: ignore + return_value=[1.0, 0.0] + ) + sample.after_lp.variables["x[3]"].to_list = Mock( # type: ignore + return_value=[1.0, 1.0] + ) + return sample + + +def test_xy(sample: Sample) -> None: + x_expected = { + "default": [ + [5.0, 0.0, 0.0], + [5.0, 1.0, 0.0], + [5.0, 1.0, 1.0], + ] + } + y_expected = { + "default": [ + [True, False], + [False, True], + [True, False], + ] + } + xy = PrimalSolutionComponent().sample_xy(sample) + assert xy is not None + x_actual, y_actual = xy + assert x_actual == x_expected + assert y_actual == y_expected + + +def test_xy_old() -> None: features = Features( variables={ "x[0]": Variable(