mirror of
https://github.com/ANL-CEEESA/LLEPE.git
synced 2025-12-06 01:48:53 -06:00
Updated iterative_fitter.py to calculate error for all species. Added new test with mean squared error.
This commit is contained in:
@@ -57,3 +57,29 @@ def ind_lmse_perturbed_obj(predicted_dict,
|
||||
objectives.append(np.mean(fun1))
|
||||
objectives = np.array(objectives)
|
||||
return objectives
|
||||
|
||||
|
||||
def mean_squared_error(predicted_dict,
|
||||
measured_df,
|
||||
species_list):
|
||||
meas_aq = np.concatenate([measured_df['{0}_aq_eq'.format(species)].values
|
||||
for species in species_list])
|
||||
pred_aq = np.concatenate([
|
||||
predicted_dict['{0}_aq_eq'.format(species)]
|
||||
for species in species_list])
|
||||
|
||||
meas_d = np.concatenate([measured_df['{0}_d_eq'.format(species)].values
|
||||
for species in species_list])
|
||||
pred_d = np.concatenate([
|
||||
predicted_dict['{0}_d_eq'.format(species)]
|
||||
for species in species_list])
|
||||
|
||||
meas_org = meas_aq * meas_d
|
||||
pred_org = np.concatenate([
|
||||
predicted_dict['{0}_org_eq'.format(species)]
|
||||
for species in species_list])
|
||||
aq_obj = (meas_aq - pred_aq)**2
|
||||
org_obj = (meas_org - pred_org)**2
|
||||
objs = np.concatenate([aq_obj, org_obj])
|
||||
obj = np.mean(objs)
|
||||
return obj
|
||||
|
||||
Reference in New Issue
Block a user