11import numpy as np
2+ import os
23
34from .population import Population
45from pe .data import Data
78from pe .constant .data import PARENT_SYN_DATA_INDEX_COLUMN_NAME
89from pe .constant .data import FROM_LAST_FLAG_COLUMN_NAME
910from pe .constant .data import VARIATION_API_FOLD_ID_COLUMN_NAME
11+ from pe .constant .data import LABEL_ID_COLUMN_NAME
1012from pe .logging import execution_logger
1113
1214
@@ -21,6 +23,7 @@ def __init__(
2123 next_variation_api_fold = 1 ,
2224 keep_selected = False ,
2325 selection_mode = "sample" ,
26+ histogram_log_folder = None ,
2427 ):
2528 """Constructor.
2629
@@ -39,6 +42,9 @@ def __init__(
3942 random sampling proportional to the histogram), "rank" (select the top samples according to the histogram).
4043 Defaults to "sample"
4144 :type selection_mode: str, optional
45+ :param histogram_log_folder: The folder to save the logs of the histogram. If it is None, the logs are not
46+ saved. Defaults to None
47+ :type histogram_log_folder: str, optional
4248 :raises ValueError: If next_variation_api_fold is 0 and keep_selected is False
4349 """
4450 super ().__init__ ()
@@ -53,6 +59,7 @@ def __init__(
5359 "next_variation_api_fold should be greater than 0 or keep_selected should be True. Otherwise, next "
5460 "synthetic data will be empty."
5561 )
62+ self ._histogram_log_folder = histogram_log_folder
5663
5764 def initial (self , label_info , num_samples ):
5865 """Generate the initial synthetic data.
@@ -98,6 +105,7 @@ def _post_process_histogram(self, syn_data):
98105 else :
99106 clipped_count = count
100107 syn_data .data_frame [POST_PROCESSED_DP_HISTOGRAM_COLUMN_NAME ] = clipped_count
108+ self ._log_histogram (syn_data )
101109 return syn_data
102110
103111 def _select_data (self , syn_data , num_samples ):
@@ -127,6 +135,21 @@ def _select_data(self, syn_data, num_samples):
127135 else :
128136 raise ValueError (f"Selection mode { self ._selection_mode } is not supported" )
129137
138+ def _log_histogram (self , syn_data ):
139+ """Log the histogram.
140+
141+ :param syn_data: The synthetic data with the histogram
142+ :type syn_data: :py:class:`pe.data.Data`
143+ """
144+ if self ._histogram_log_folder is None :
145+ return
146+ labels = set (list (syn_data .data_frame [LABEL_ID_COLUMN_NAME ].values ))
147+ assert len (labels ) == 1
148+ label = list (labels )[0 ]
149+ iteration = syn_data .metadata ["iteration" ]
150+ log_folder = os .path .join (self ._histogram_log_folder , f"{ iteration } " , f"label-id{ label } " )
151+ syn_data .save_checkpoint (log_folder )
152+
130153 def next (self , syn_data , num_samples ):
131154 """Generate the next synthetic data.
132155
0 commit comments