Skip to content

Commit 0720ae7

Browse files
committed
fixing Athena column name sanitation to actually work in all cases
1 parent d28b265 commit 0720ae7

File tree

4 files changed

+20
-23
lines changed

4 files changed

+20
-23
lines changed

examples/models/smiles_to_md_v1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from workbench.api import FeatureSet, ModelType, Model
55
from workbench.utils.model_utils import get_custom_script_path
66

7-
fs_name = "aqsol_features"
8-
# fs_name = "solubility_featurized_class_0_fs"
7+
# fs_name = "aqsol_features"
8+
fs_name = "solubility_featurized_class_0_fs"
99

1010

1111
script_path = get_custom_script_path("chem_info", "molecular_descriptors.py")

src/workbench/model_scripts/custom_models/chem_info/molecular_descriptors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def output_fn(output_df, accept_type):
7575
# Prediction function
7676
def predict_fn(df, model):
7777

78-
# Standardize the molecule (remove salts) and then compute descriptors
79-
df = standardize(df)
78+
# Standardize the molecule (extract salts) and then compute descriptors
79+
df = standardize(df, extract_salts=True)
8080
df = compute_descriptors(df)
8181
return df

src/workbench/utils/chem_utils/mol_descriptors.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -344,21 +344,16 @@ def compute_descriptors(df: pd.DataFrame, include_mordred: bool = True, include_
344344
stereo_count = len(stereo_df.columns) if include_stereo else 0
345345
logger.info(f"Descriptor breakdown: RDKit={rdkit_count}, Mordred={mordred_count}, Stereo={stereo_count}")
346346

347-
# Note: The results are often stored in an AWS Athena table.
348-
# Athena has restrictions on column names:
349-
# - Must be lowercase
350-
# - No special characters except underscore
351-
# - No spaces
352-
# https://docs.aws.amazon.com/athena/latest/ug/tables-databases-columns-names.html
353-
safe_columns = [re.sub(r"_+", "_", re.sub(r"[^a-z0-9_]", "_", col.lower())) for col in result.columns]
354-
355-
# Check for duplicates before dropping
356-
if len(safe_columns) != len(set(safe_columns)):
357-
from collections import Counter
358-
359-
duplicates = {col for col, count in Counter(safe_columns).items() if count > 1}
360-
logger.warning(f"Duplicate column names after sanitization: {duplicates} - dropping duplicates!")
361-
result.columns = safe_columns
347+
# Sanitize column names for AWS Athena compatibility
348+
# - Must be lowercase, no special characters except underscore, no spaces
349+
result.columns = [
350+
re.sub(r"_+", "_", re.sub(r"[^a-z0-9_]", "_", col.lower()))
351+
for col in result.columns
352+
]
353+
354+
# Drop duplicate columns if any exist after sanitization
355+
if result.columns.duplicated().any():
356+
logger.warning(f"Duplicate column names after sanitization - dropping duplicates!")
362357
result = result.loc[:, ~result.columns.duplicated()]
363358

364359
return result

src/workbench/utils/pandas_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,17 +152,19 @@ def compare_dataframes(df1: pd.DataFrame, df2: pd.DataFrame, display_columns: li
152152

153153
# Check for differences in common columns
154154
for column in common_columns:
155-
if pd.api.types.is_string_dtype(df1[column]) or pd.api.types.is_string_dtype(df2[column]):
155+
if pd.api.types.is_string_dtype(df1[column]) and pd.api.types.is_string_dtype(df2[column]):
156156
# String comparison with NaNs treated as equal
157157
differences = ~(df1[column].fillna("") == df2[column].fillna(""))
158158
elif pd.api.types.is_float_dtype(df1[column]) or pd.api.types.is_float_dtype(df2[column]):
159159
# Float comparison within epsilon with NaNs treated as equal
160160
differences = ~((df1[column] - df2[column]).abs() <= epsilon) & ~(
161-
pd.isna(df1[column]) & pd.isna(df2[column])
161+
pd.isna(df1[column]) & pd.isna(df2[column])
162162
)
163163
else:
164-
# Other types (e.g., int) with NaNs treated as equal
165-
differences = ~(df1[column].fillna(0) == df2[column].fillna(0))
164+
# Other types (int, Int64, etc.) - compare with NaNs treated as equal
165+
differences = (df1[column] != df2[column]) & ~(
166+
pd.isna(df1[column]) & pd.isna(df2[column])
167+
)
166168

167169
# If differences exist, display them
168170
if differences.any():

0 commit comments

Comments
 (0)