mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 01:18:52 -06:00
Train without loading all instances to memory
This commit is contained in:
@@ -6,6 +6,7 @@ import logging
|
||||
import pickle
|
||||
import os
|
||||
import tempfile
|
||||
import gzip
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Optional, List
|
||||
@@ -198,11 +199,18 @@ class LearningSolver:
|
||||
"""
|
||||
|
||||
filename = None
|
||||
fileformat = None
|
||||
if isinstance(instance, str):
|
||||
filename = instance
|
||||
logger.info("Reading: %s" % filename)
|
||||
with open(filename, "rb") as file:
|
||||
instance = pickle.load(file)
|
||||
if filename.endswith(".gz"):
|
||||
fileformat = "pickle-gz"
|
||||
with gzip.GzipFile(filename, "rb") as file:
|
||||
instance = pickle.load(file)
|
||||
else:
|
||||
fileformat = "pickle"
|
||||
with open(filename, "rb") as file:
|
||||
instance = pickle.load(file)
|
||||
|
||||
if model is None:
|
||||
model = instance.to_model()
|
||||
@@ -260,9 +268,12 @@ class LearningSolver:
|
||||
if len(output) == 0:
|
||||
output_filename = filename
|
||||
logger.info("Writing: %s" % output_filename)
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tmp:
|
||||
pickle.dump(instance, tmp)
|
||||
os.replace(tmp.name, output_filename)
|
||||
if fileformat == "pickle":
|
||||
with open(output_filename, "wb") as file:
|
||||
pickle.dump(instance, file)
|
||||
else:
|
||||
with gzip.GzipFile(output_filename, "wb") as file:
|
||||
pickle.dump(instance, file)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
Reference in New Issue
Block a user