Skip to content

Commit 761cd14

Browse files
committed
add options to log histograms
1 parent b510464 commit 761cd14

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

pe/population/pe_population.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import os
23

34
from .population import Population
45
from pe.data import Data
@@ -7,6 +8,7 @@
78
from pe.constant.data import PARENT_SYN_DATA_INDEX_COLUMN_NAME
89
from pe.constant.data import FROM_LAST_FLAG_COLUMN_NAME
910
from pe.constant.data import VARIATION_API_FOLD_ID_COLUMN_NAME
11+
from pe.constant.data import LABEL_ID_COLUMN_NAME
1012
from pe.logging import execution_logger
1113

1214

@@ -21,6 +23,7 @@ def __init__(
2123
next_variation_api_fold=1,
2224
keep_selected=False,
2325
selection_mode="sample",
26+
histogram_log_folder=None,
2427
):
2528
"""Constructor.
2629
@@ -39,6 +42,9 @@ def __init__(
3942
random sampling proportional to the histogram), "rank" (select the top samples according to the histogram).
4043
Defaults to "sample"
4144
:type selection_mode: str, optional
45+
:param histogram_log_folder: The folder to save the logs of the histogram. If it is None, the logs are not
46+
saved. Defaults to None
47+
:type histogram_log_folder: str, optional
4248
:raises ValueError: If next_variation_api_fold is 0 and keep_selected is False
4349
"""
4450
super().__init__()
@@ -53,6 +59,7 @@ def __init__(
5359
"next_variation_api_fold should be greater than 0 or keep_selected should be True. Otherwise, next "
5460
"synthetic data will be empty."
5561
)
62+
self._histogram_log_folder = histogram_log_folder
5663

5764
def initial(self, label_info, num_samples):
5865
"""Generate the initial synthetic data.
@@ -98,6 +105,7 @@ def _post_process_histogram(self, syn_data):
98105
else:
99106
clipped_count = count
100107
syn_data.data_frame[POST_PROCESSED_DP_HISTOGRAM_COLUMN_NAME] = clipped_count
108+
self._log_histogram(syn_data)
101109
return syn_data
102110

103111
def _select_data(self, syn_data, num_samples):
@@ -127,6 +135,21 @@ def _select_data(self, syn_data, num_samples):
127135
else:
128136
raise ValueError(f"Selection mode {self._selection_mode} is not supported")
129137

138+
def _log_histogram(self, syn_data):
139+
"""Log the histogram.
140+
141+
:param syn_data: The synthetic data with the histogram
142+
:type syn_data: :py:class:`pe.data.Data`
143+
"""
144+
if self._histogram_log_folder is None:
145+
return
146+
labels = set(list(syn_data.data_frame[LABEL_ID_COLUMN_NAME].values))
147+
assert len(labels) == 1
148+
label = list(labels)[0]
149+
iteration = syn_data.metadata["iteration"]
150+
log_folder = os.path.join(self._histogram_log_folder, f"{iteration}", f"label-id{label}")
151+
syn_data.save_checkpoint(log_folder)
152+
130153
def next(self, syn_data, num_samples):
131154
"""Generate the next synthetic data.
132155

0 commit comments

Comments
 (0)