From 4d5b7e971c013f5ba22cbd5c24f20a924f0feb6a Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Thu, 1 Feb 2024 13:13:10 -0600 Subject: [PATCH] Minor fixes --- src/solvers/jump.jl | 9 +++++++++ test/src/components/test_cuts.jl | 4 +--- test/src/problems/test_stab.jl | 5 ----- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/solvers/jump.jl b/src/solvers/jump.jl index b7bb2b2..9a84d02 100644 --- a/src/solvers/jump.jl +++ b/src/solvers/jump.jl @@ -17,6 +17,15 @@ Base.@kwdef mutable struct _JumpModelExtData cuts_separate::Union{Function,Nothing} = nothing end +function JuMP.copy_extension_data( + ::_JumpModelExtData, + new_model::AbstractModel, + ::AbstractModel, +) + # Do not transfer any extension data to the new model + new_model.ext[:miplearn] = _JumpModelExtData() +end + # ----------------------------------------------------------------------------- function _add_constrs( diff --git a/test/src/components/test_cuts.jl b/test/src/components/test_cuts.jl index 1562648..f466732 100644 --- a/test/src/components/test_cuts.jl +++ b/test/src/components/test_cuts.jl @@ -28,15 +28,13 @@ end function test_cuts() data_filenames = ["$BASEDIR/../fixtures/stab-n50-0000$i.pkl.gz" for i in 0:0] - clf = pyimport("sklearn.neighbors").KNeighborsClassifier(n_neighbors=1) + clf = pyimport("sklearn.dummy").DummyClassifier() extractor = H5FieldsExtractor( instance_fields=["static_var_obj_coeffs"], ) comp = MemorizingCutsComponent(clf=clf, extractor=extractor) solver = LearningSolver(components=[comp]) solver.fit(data_filenames) - @show comp.n_features_ - @show comp.n_targets_ stats = solver.optimize( data_filenames[1], data -> build_stab_model_jump(data, optimizer=SCIP.Optimizer), diff --git a/test/src/problems/test_stab.jl b/test/src/problems/test_stab.jl index 60f27b3..7a11a56 100644 --- a/test/src/problems/test_stab.jl +++ b/test/src/problems/test_stab.jl @@ -6,11 +6,6 @@ using PyCall using SCIP function test_problems_stab() - test_problems_stab_1() - test_problems_stab_2() -end - -function test_problems_stab_1() nx = pyimport("networkx") data = MaxWeightStableSetData( graph=nx.gnp_random_graph(25, 0.5, seed=42),