Implement tests for ConvertTightIneqsIntoEqsStep

This commit is contained in:
2021-01-07 10:29:22 -06:00
parent 0377b5b546
commit d8dc8471aa
4 changed files with 40 additions and 5 deletions

View File

@@ -5,6 +5,7 @@
import logging
from copy import deepcopy
import numpy as np
from tqdm import tqdm
from miplearn import Component
@@ -124,7 +125,7 @@ class ConvertTightIneqsIntoEqsStep(Component):
if category not in self.classifiers:
continue
y[category] = []
# x_cat = np.array(x_cat)
x_cat = np.array(x_cat)
proba = self.classifiers[category].predict_proba(x_cat)
for i in range(len(proba)):
if proba[i][1] >= self.threshold:

View File

@@ -0,0 +1,34 @@
from miplearn import LearningSolver, GurobiSolver
from miplearn.components.steps.convert_tight import ConvertTightIneqsIntoEqsStep
from miplearn.components.steps.relax_integrality import RelaxIntegralityStep
from miplearn.problems.knapsack import GurobiKnapsackInstance
def test_convert_tight_usage():
instance = GurobiKnapsackInstance(
weights=[3.0, 5.0, 10.0],
prices=[1.0, 1.0, 1.0],
capacity=16.0,
)
solver = LearningSolver(
solver=GurobiSolver(),
components=[
RelaxIntegralityStep(),
ConvertTightIneqsIntoEqsStep(),
],
)
# Solve original problem
solver.solve(instance)
original_upper_bound = instance.upper_bound
# Should collect training data
assert hasattr(instance, "slacks")
assert instance.slacks["eq_capacity"] == 0.0
# Fit and resolve
solver.fit([instance])
solver.solve(instance)
# Objective value should be the same
assert instance.upper_bound == original_upper_bound

View File

@@ -2,14 +2,14 @@
# Copyright (C) 2020, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
import networkx as nx
import numpy as np
import pyomo.environ as pe
import networkx as nx
from miplearn import Instance
import random
from scipy.stats import uniform, randint, bernoulli
from scipy.stats import uniform, randint
from scipy.stats.distributions import rv_frozen
from miplearn import Instance
class ChallengeA:
def __init__(