Minor fixes

feature/replay^2
Alinson S. Xavier 2 years ago
parent d69c4bbfa7
commit 4d5b7e971c

@ -17,6 +17,15 @@ Base.@kwdef mutable struct _JumpModelExtData
cuts_separate::Union{Function,Nothing} = nothing cuts_separate::Union{Function,Nothing} = nothing
end end
function JuMP.copy_extension_data(
::_JumpModelExtData,
new_model::AbstractModel,
::AbstractModel,
)
# Do not transfer any extension data to the new model
new_model.ext[:miplearn] = _JumpModelExtData()
end
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
function _add_constrs( function _add_constrs(

@ -28,15 +28,13 @@ end
function test_cuts() function test_cuts()
data_filenames = ["$BASEDIR/../fixtures/stab-n50-0000$i.pkl.gz" for i in 0:0] data_filenames = ["$BASEDIR/../fixtures/stab-n50-0000$i.pkl.gz" for i in 0:0]
clf = pyimport("sklearn.neighbors").KNeighborsClassifier(n_neighbors=1) clf = pyimport("sklearn.dummy").DummyClassifier()
extractor = H5FieldsExtractor( extractor = H5FieldsExtractor(
instance_fields=["static_var_obj_coeffs"], instance_fields=["static_var_obj_coeffs"],
) )
comp = MemorizingCutsComponent(clf=clf, extractor=extractor) comp = MemorizingCutsComponent(clf=clf, extractor=extractor)
solver = LearningSolver(components=[comp]) solver = LearningSolver(components=[comp])
solver.fit(data_filenames) solver.fit(data_filenames)
@show comp.n_features_
@show comp.n_targets_
stats = solver.optimize( stats = solver.optimize(
data_filenames[1], data_filenames[1],
data -> build_stab_model_jump(data, optimizer=SCIP.Optimizer), data -> build_stab_model_jump(data, optimizer=SCIP.Optimizer),

@ -6,11 +6,6 @@ using PyCall
using SCIP using SCIP
function test_problems_stab() function test_problems_stab()
test_problems_stab_1()
test_problems_stab_2()
end
function test_problems_stab_1()
nx = pyimport("networkx") nx = pyimport("networkx")
data = MaxWeightStableSetData( data = MaxWeightStableSetData(
graph=nx.gnp_random_graph(25, 0.5, seed=42), graph=nx.gnp_random_graph(25, 0.5, seed=42),

Loading…
Cancel
Save