mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-06 09:28:51 -06:00
Implement load; update fit
This commit is contained in:
@@ -6,7 +6,7 @@ import gc
|
||||
import gzip
|
||||
import os
|
||||
import pickle
|
||||
from typing import Optional, Any, List, cast, IO, TYPE_CHECKING, Dict
|
||||
from typing import Optional, Any, List, cast, IO, TYPE_CHECKING, Dict, Callable
|
||||
|
||||
import numpy as np
|
||||
from overrides import overrides
|
||||
@@ -177,3 +177,9 @@ def save(objs: List[Any], dirname: str) -> List[str]:
|
||||
filenames.append(filename)
|
||||
write_pickle_gz(obj, filename)
|
||||
return filenames
|
||||
|
||||
|
||||
def load(filename: str, build_model: Callable) -> Any:
|
||||
with gzip.GzipFile(filename, "rb") as file:
|
||||
data = pickle.load(cast(IO[bytes], file))
|
||||
return build_model(data)
|
||||
|
||||
Reference in New Issue
Block a user