BasicCollector: Do not crash on exception

This commit is contained in:
2024-02-26 16:41:50 -06:00
parent 8a02e22a35
commit 0534d50af3
2 changed files with 49 additions and 40 deletions

View File

@@ -9,6 +9,7 @@ import sys
from io import StringIO from io import StringIO
from os.path import exists from os.path import exists
from typing import Callable, List, Any from typing import Callable, List, Any
import traceback
from ..h5 import H5File from ..h5 import H5File
from ..io import _RedirectOutput, gzip, _to_h5_filename from ..io import _RedirectOutput, gzip, _to_h5_filename
@@ -29,52 +30,57 @@ class BasicCollector:
verbose: bool = False, verbose: bool = False,
) -> None: ) -> None:
def _collect(data_filename: str) -> None: def _collect(data_filename: str) -> None:
h5_filename = _to_h5_filename(data_filename) try:
mps_filename = h5_filename.replace(".h5", ".mps") h5_filename = _to_h5_filename(data_filename)
mps_filename = h5_filename.replace(".h5", ".mps")
if exists(h5_filename): if exists(h5_filename):
# Try to read optimal solution # Try to read optimal solution
mip_var_values = None mip_var_values = None
try: try:
with H5File(h5_filename, "r") as h5: with H5File(h5_filename, "r") as h5:
mip_var_values = h5.get_array("mip_var_values") mip_var_values = h5.get_array("mip_var_values")
except: except:
pass pass
if mip_var_values is None: if mip_var_values is None:
print(f"Removing empty/corrupted h5 file: {h5_filename}") print(f"Removing empty/corrupted h5 file: {h5_filename}")
os.remove(h5_filename) os.remove(h5_filename)
else: else:
return return
with H5File(h5_filename, "w") as h5: with H5File(h5_filename, "w") as h5:
streams: List[Any] = [StringIO()] streams: List[Any] = [StringIO()]
if verbose: if verbose:
streams += [sys.stdout] streams += [sys.stdout]
with _RedirectOutput(streams): with _RedirectOutput(streams):
# Load and extract static features # Load and extract static features
model = build_model(data_filename) model = build_model(data_filename)
model.extract_after_load(h5) model.extract_after_load(h5)
if not self.skip_lp: if not self.skip_lp:
# Solve LP relaxation # Solve LP relaxation
relaxed = model.relax() relaxed = model.relax()
relaxed.optimize() relaxed.optimize()
relaxed.extract_after_lp(h5) relaxed.extract_after_lp(h5)
# Solve MIP # Solve MIP
model.optimize() model.optimize()
model.extract_after_mip(h5) model.extract_after_mip(h5)
if self.write_mps: if self.write_mps:
# Add lazy constraints to model # Add lazy constraints to model
model._lazy_enforce_collected() model._lazy_enforce_collected()
# Save MPS file # Save MPS file
model.write(mps_filename) model.write(mps_filename)
gzip(mps_filename) gzip(mps_filename)
h5.put_scalar("mip_log", streams[0].getvalue())
except:
print(f"Error processing: data_filename")
traceback.print_exc()
h5.put_scalar("mip_log", streams[0].getvalue())
if n_jobs > 1: if n_jobs > 1:
p_umap( p_umap(

View File

@@ -87,7 +87,10 @@ def read_pkl_gz(filename: str) -> Any:
def _to_h5_filename(data_filename: str) -> str: def _to_h5_filename(data_filename: str) -> str:
output = f"{data_filename}.h5" output = f"{data_filename}.h5"
output = output.replace(".gz.h5", ".h5") output = output.replace(".gz.h5", ".h5")
output = output.replace(".json.h5", ".h5") output = output.replace(".csv.h5", ".h5")
output = output.replace(".pkl.h5", ".h5")
output = output.replace(".jld2.h5", ".h5") output = output.replace(".jld2.h5", ".h5")
output = output.replace(".json.h5", ".h5")
output = output.replace(".lp.h5", ".h5")
output = output.replace(".mps.h5", ".h5")
output = output.replace(".pkl.h5", ".h5")
return output return output