diff --git a/src/solvers/jump.jl b/src/solvers/jump.jl index 6c45a4e..163637e 100644 --- a/src/solvers/jump.jl +++ b/src/solvers/jump.jl @@ -372,6 +372,13 @@ function _set_warm_starts(model::JuMP.Model, var_names, var_values, stats) end function _write(model::JuMP.Model, filename) + ext = model.ext[:miplearn] + if ext.lazy_separate !== nothing + set_attribute(model, MOI.LazyConstraintCallback(), nothing) + end + if ext.cuts_separate !== nothing + set_attribute(model, MOI.UserCutCallback(), nothing) + end write_to_file(model, filename) end @@ -439,6 +446,13 @@ function __init_solvers_jump__() function lazy_enforce(self, model, violations) self.inner.ext[:miplearn].lazy_enforce(violations) end + + function _lazy_enforce_collected(self) + ext = self.inner.ext[:miplearn] + if ext.lazy_enforce !== nothing + ext.lazy_enforce(ext.lazy) + end + end end copy!(JumpModel, Class) end diff --git a/test/src/components/test_cuts.jl b/test/src/components/test_cuts.jl index f9cccd9..4b1bb93 100644 --- a/test/src/components/test_cuts.jl +++ b/test/src/components/test_cuts.jl @@ -17,7 +17,7 @@ function gen_stab() ) data = gen.generate(1) data_filenames = write_pkl_gz(data, "$BASEDIR/../fixtures", prefix="stab-n50-") - collector = BasicCollector(write_mps=false) + collector = BasicCollector() collector.collect( data_filenames, data -> build_stab_model_jump(data, optimizer=SCIP.Optimizer), diff --git a/test/src/components/test_lazy.jl b/test/src/components/test_lazy.jl index 291b5fa..b62c1d8 100644 --- a/test/src/components/test_lazy.jl +++ b/test/src/components/test_lazy.jl @@ -20,7 +20,7 @@ function gen_tsp() ) data = gen.generate(1) data_filenames = write_pkl_gz(data, "$BASEDIR/../fixtures", prefix="tsp-n20-") - collector = BasicCollector(write_mps=false) + collector = BasicCollector() collector.collect( data_filenames, data -> build_tsp_model_jump(data, optimizer=GLPK.Optimizer), diff --git a/test/src/test_usage.jl b/test/src/test_usage.jl index 3e41332..1965e0e 100644 --- a/test/src/test_usage.jl +++ b/test/src/test_usage.jl @@ -28,7 +28,7 @@ function test_usage() ) @debug "Collecting training data..." - bc = BasicCollector(write_mps=false) + bc = BasicCollector() bc.collect(data_filenames, build_setcover_model_jump) @debug "Training models..."