diff --git a/miplearn/components/steps/convert_tight.py b/miplearn/components/steps/convert_tight.py index 013ff59..7adc84d 100644 --- a/miplearn/components/steps/convert_tight.py +++ b/miplearn/components/steps/convert_tight.py @@ -5,6 +5,7 @@ import logging from copy import deepcopy +import numpy as np from tqdm import tqdm from miplearn import Component @@ -124,7 +125,7 @@ class ConvertTightIneqsIntoEqsStep(Component): if category not in self.classifiers: continue y[category] = [] - # x_cat = np.array(x_cat) + x_cat = np.array(x_cat) proba = self.classifiers[category].predict_proba(x_cat) for i in range(len(proba)): if proba[i][1] >= self.threshold: diff --git a/miplearn/components/steps/tests/__init__.py b/miplearn/components/steps/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/miplearn/components/steps/tests/convert_tight_test.py b/miplearn/components/steps/tests/convert_tight_test.py new file mode 100644 index 0000000..d64bba0 --- /dev/null +++ b/miplearn/components/steps/tests/convert_tight_test.py @@ -0,0 +1,34 @@ +from miplearn import LearningSolver, GurobiSolver +from miplearn.components.steps.convert_tight import ConvertTightIneqsIntoEqsStep +from miplearn.components.steps.relax_integrality import RelaxIntegralityStep +from miplearn.problems.knapsack import GurobiKnapsackInstance + + +def test_convert_tight_usage(): + instance = GurobiKnapsackInstance( + weights=[3.0, 5.0, 10.0], + prices=[1.0, 1.0, 1.0], + capacity=16.0, + ) + solver = LearningSolver( + solver=GurobiSolver(), + components=[ + RelaxIntegralityStep(), + ConvertTightIneqsIntoEqsStep(), + ], + ) + + # Solve original problem + solver.solve(instance) + original_upper_bound = instance.upper_bound + + # Should collect training data + assert hasattr(instance, "slacks") + assert instance.slacks["eq_capacity"] == 0.0 + + # Fit and resolve + solver.fit([instance]) + solver.solve(instance) + + # Objective value should be the same + assert instance.upper_bound == original_upper_bound diff --git a/miplearn/problems/stab.py b/miplearn/problems/stab.py index 3d4a285..03ea558 100644 --- a/miplearn/problems/stab.py +++ b/miplearn/problems/stab.py @@ -2,14 +2,14 @@ # Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved. # Released under the modified BSD license. See COPYING.md for more details. +import networkx as nx import numpy as np import pyomo.environ as pe -import networkx as nx -from miplearn import Instance -import random -from scipy.stats import uniform, randint, bernoulli +from scipy.stats import uniform, randint from scipy.stats.distributions import rv_frozen +from miplearn import Instance + class ChallengeA: def __init__(