mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
BasicCollector: Do not crash on exception
This commit is contained in:
@@ -9,6 +9,7 @@ import sys
|
|||||||
from io import StringIO
|
from io import StringIO
|
||||||
from os.path import exists
|
from os.path import exists
|
||||||
from typing import Callable, List, Any
|
from typing import Callable, List, Any
|
||||||
|
import traceback
|
||||||
|
|
||||||
from ..h5 import H5File
|
from ..h5 import H5File
|
||||||
from ..io import _RedirectOutput, gzip, _to_h5_filename
|
from ..io import _RedirectOutput, gzip, _to_h5_filename
|
||||||
@@ -29,52 +30,57 @@ class BasicCollector:
|
|||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
def _collect(data_filename: str) -> None:
|
def _collect(data_filename: str) -> None:
|
||||||
h5_filename = _to_h5_filename(data_filename)
|
try:
|
||||||
mps_filename = h5_filename.replace(".h5", ".mps")
|
h5_filename = _to_h5_filename(data_filename)
|
||||||
|
mps_filename = h5_filename.replace(".h5", ".mps")
|
||||||
|
|
||||||
if exists(h5_filename):
|
if exists(h5_filename):
|
||||||
# Try to read optimal solution
|
# Try to read optimal solution
|
||||||
mip_var_values = None
|
mip_var_values = None
|
||||||
try:
|
try:
|
||||||
with H5File(h5_filename, "r") as h5:
|
with H5File(h5_filename, "r") as h5:
|
||||||
mip_var_values = h5.get_array("mip_var_values")
|
mip_var_values = h5.get_array("mip_var_values")
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if mip_var_values is None:
|
if mip_var_values is None:
|
||||||
print(f"Removing empty/corrupted h5 file: {h5_filename}")
|
print(f"Removing empty/corrupted h5 file: {h5_filename}")
|
||||||
os.remove(h5_filename)
|
os.remove(h5_filename)
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
|
||||||
with H5File(h5_filename, "w") as h5:
|
with H5File(h5_filename, "w") as h5:
|
||||||
streams: List[Any] = [StringIO()]
|
streams: List[Any] = [StringIO()]
|
||||||
if verbose:
|
if verbose:
|
||||||
streams += [sys.stdout]
|
streams += [sys.stdout]
|
||||||
with _RedirectOutput(streams):
|
with _RedirectOutput(streams):
|
||||||
# Load and extract static features
|
# Load and extract static features
|
||||||
model = build_model(data_filename)
|
model = build_model(data_filename)
|
||||||
model.extract_after_load(h5)
|
model.extract_after_load(h5)
|
||||||
|
|
||||||
if not self.skip_lp:
|
if not self.skip_lp:
|
||||||
# Solve LP relaxation
|
# Solve LP relaxation
|
||||||
relaxed = model.relax()
|
relaxed = model.relax()
|
||||||
relaxed.optimize()
|
relaxed.optimize()
|
||||||
relaxed.extract_after_lp(h5)
|
relaxed.extract_after_lp(h5)
|
||||||
|
|
||||||
# Solve MIP
|
# Solve MIP
|
||||||
model.optimize()
|
model.optimize()
|
||||||
model.extract_after_mip(h5)
|
model.extract_after_mip(h5)
|
||||||
|
|
||||||
if self.write_mps:
|
if self.write_mps:
|
||||||
# Add lazy constraints to model
|
# Add lazy constraints to model
|
||||||
model._lazy_enforce_collected()
|
model._lazy_enforce_collected()
|
||||||
|
|
||||||
# Save MPS file
|
# Save MPS file
|
||||||
model.write(mps_filename)
|
model.write(mps_filename)
|
||||||
gzip(mps_filename)
|
gzip(mps_filename)
|
||||||
|
|
||||||
|
h5.put_scalar("mip_log", streams[0].getvalue())
|
||||||
|
except:
|
||||||
|
print(f"Error processing: data_filename")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
h5.put_scalar("mip_log", streams[0].getvalue())
|
|
||||||
|
|
||||||
if n_jobs > 1:
|
if n_jobs > 1:
|
||||||
p_umap(
|
p_umap(
|
||||||
|
|||||||
@@ -87,7 +87,10 @@ def read_pkl_gz(filename: str) -> Any:
|
|||||||
def _to_h5_filename(data_filename: str) -> str:
|
def _to_h5_filename(data_filename: str) -> str:
|
||||||
output = f"{data_filename}.h5"
|
output = f"{data_filename}.h5"
|
||||||
output = output.replace(".gz.h5", ".h5")
|
output = output.replace(".gz.h5", ".h5")
|
||||||
output = output.replace(".json.h5", ".h5")
|
output = output.replace(".csv.h5", ".h5")
|
||||||
output = output.replace(".pkl.h5", ".h5")
|
|
||||||
output = output.replace(".jld2.h5", ".h5")
|
output = output.replace(".jld2.h5", ".h5")
|
||||||
|
output = output.replace(".json.h5", ".h5")
|
||||||
|
output = output.replace(".lp.h5", ".h5")
|
||||||
|
output = output.replace(".mps.h5", ".h5")
|
||||||
|
output = output.replace(".pkl.h5", ".h5")
|
||||||
return output
|
return output
|
||||||
|
|||||||
Reference in New Issue
Block a user