diff --git a/molSimplify/Classes/globalvars.py b/molSimplify/Classes/globalvars.py
index 48fa3006..73eceff1 100644
--- a/molSimplify/Classes/globalvars.py
+++ b/molSimplify/Classes/globalvars.py
@@ -338,75 +338,10 @@
            'y': 1, 'zr': 2, 'nb': 3, 'mo': 4, 'tc': 5, 'ru': 6, 'rh': 7, 'pd': 8, 'ag': 9, 'cd': 10,
            'hf': 2, 'ta': 3, 'w': 4, 're': 5, 'os': 6, 'ir': 7, 'pt': 8, 'au': 9, 'hg': 10}

-# Default spins for each d-electron count (metal/oxidation state-specific version in metal_ox_spinlist)
+# Default spins for each d-electron count (make this metal/oxidation state specific)
 defaultspins = {0: '1', 1: '2', 2: '3', 3: '4', 4: '5',
                 5: '6', 6: '5', 7: '4', 8: '3', 9: '2', 10: '1'}

-# d-electron counts for each transition metal and oxidation state
-# ground states determined from NIST Atomic Spectra Database https://physics.nist.gov/PhysRefData/ASD/levels_form.html, accessed 2025
-metal_ox_dlist = {'Ag(1)': 10, 'Ag(2)': 9, 'Ag(3)': 8,
-                  'Au(0)': 10, 'Au(1)': 10, 'Au(2)': 9, 'Au(3)': 8, 'Au(4)': 7, 'Au(5)': 6,
-                  'Cd(1)': 10, 'Cd(2)': 10,
-                  'Co(0)': 7, 'Co(1)': 8, 'Co(2)': 7, 'Co(3)': 6, 'Co(4)': 5, 'Co(5)': 4,
-                  'Cr(0)': 5, 'Cr(1)': 5, 'Cr(2)': 4, 'Cr(3)': 3, 'Cr(4)': 2, 'Cr(5)': 1, 'Cr(6)': 0, 'Cr(7)': 0,
-                  'Cu(0)': 10, 'Cu(1)': 10, 'Cu(2)': 9, 'Cu(3)': 8, 'Cu(4)': 7,
-                  'Fe(0)': 6, 'Fe(1)': 6, 'Fe(2)': 6, 'Fe(3)': 5, 'Fe(4)': 4, 'Fe(5)': 3,
-                  'Hf(0)': 2, 'Hf(2)': 2, 'Hf(3)': 1, 'Hf(4)': 0,
-                  'Hg(1)': 10, 'Hg(2)': 10,
-                  'Ir(0)': 7, 'Ir(1)': 7, 'Ir(2)': 7, 'Ir(3)': 6, 'Ir(4)': 5, 'Ir(5)': 4,
-                  'Mn(0)': 5, 'Mn(1)': 5, 'Mn(2)': 5, 'Mn(3)': 4, 'Mn(4)': 3, 'Mn(5)': 2, 'Mn(6)': 1, 'Mn(7)': 0,
-                  'Mo(0)': 5, 'Mo(1)': 5, 'Mo(2)': 4, 'Mo(3)': 3, 'Mo(4)': 2, 'Mo(5)': 1, 'Mo(6)': 0, 'Mo(7)': 0,
-                  'Nb(0)': 4, 'Nb(1)': 4, 'Nb(2)': 3, 'Nb(3)': 2, 'Nb(4)': 1, 'Nb(5)': 0,
-                  'Ni(0)': 8, 'Ni(1)': 9, 'Ni(2)': 8, 'Ni(3)': 7, 'Ni(4)': 6, 'Ni(5)': 5,
-                  'Os(0)': 6, 'Os(1)': 6, 'Os(2)': 6, 'Os(3)': 5, 'Os(4)': 4, 'Os(5)': 3, 'Os(6)': 2, 'Os(8)': 0,
-                  'Pd(0)': 10, 'Pd(1)': 9, 'Pd(2)': 8, 'Pd(3)': 7, 'Pd(4)': 6,
-                  'Pt(0)': 9, 'Pt(1)': 9, 'Pt(2)': 8, 'Pt(3)': 7, 'Pt(4)': 6, 'Pt(6)': 4,
-                  'Re(0)': 5, 'Re(1)': 5, 'Re(2)': 5, 'Re(3)': 4, 'Re(4)': 3, 'Re(5)': 2, 'Re(6)': 1, 'Re(7)': 0,
-                  'Rh(0)': 8, 'Rh(1)': 8, 'Rh(2)': 7, 'Rh(3)': 6, 'Rh(4)': 5, 'Rh(5)': 4,
-                  'Ru(0)': 7, 'Ru(1)': 7, 'Ru(2)': 6, 'Ru(3)': 5, 'Ru(4)': 4, 'Ru(5)': 3, 'Ru(6)': 2,
-                  'Sc(1)': 1, 'Sc(2)': 1, 'Sc(3)': 0,
-                  'Ta(0)': 3, 'Ta(1)': 3, 'Ta(2)': 3, 'Ta(3)': 2, 'Ta(4)': 1, 'Ta(5)': 0,
-                  'Tc(0)': 5, 'Tc(1)': 5, 'Tc(2)': 5, 'Tc(3)': 4, 'Tc(4)': 3, 'Tc(5)': 2, 'Tc(6)': 1, 'Tc(7)': 0,
-                  'Ti(0)': 2, 'Ti(1)': 2, 'Ti(2)': 2, 'Ti(3)': 1, 'Ti(4)': 0,
-                  'V(0)': 3, 'V(1)': 4, 'V(2)': 3, 'V(3)': 2, 'V(4)': 1, 'V(5)': 0,
-                  'W(0)': 4, 'W(1)': 4, 'W(2)': 4, 'W(3)': 3, 'W(4)': 2, 'W(5)': 1, 'W(6)': 0,
-                  'Y(2)': 1, 'Y(3)': 0, 'Y(4)': 0, 'Y(5)': 0,
-                  'Zn(0)': 10, 'Zn(1)': 10, 'Zn(2)': 10, 'Zn(4)': 8, 'Zn(5)': 7,
-                  'Zr(0)': 2, 'Zr(2)': 2, 'Zr(3)': 1, 'Zr(4)': 0}
-
-# allowed spins (low, intermediate, high) for each transition metal and oxidation state
-# assumes only 3d metals exhibit high spin states, all exhibit low and intermediate spin states
-# determined from NIST Atomic Spectra Database https://physics.nist.gov/PhysRefData/ASD/levels_form.html, accessed 2025
-metal_ox_spinlist = {'Ag(1)': [1, 3], 'Ag(2)': [2, 4], 'Ag(3)': [1, 3],
-                     'Au(0)': [2, 4], 'Au(1)': [1, 3], 'Au(2)': [2, 4], 'Au(3)': [1, 3], 'Au(4)': [2, 4], 'Au(5)': [1, 3],
-                     'Cd(1)': [2, 4], 'Cd(2)': [1, 3],
-                     'Co(0)': [2, 4, 6], 'Co(1)': [1, 3, 5], 'Co(2)': [2, 4, 6], 'Co(3)': [1, 3, 5], 'Co(4)': [2, 4, 6], 'Co(5)': [1, 3, 5],
-                     'Cr(0)': [1, 3, 5], 'Cr(1)': [2, 4, 6], 'Cr(2)': [1, 3, 5], 'Cr(3)': [2, 4, 6], 'Cr(4)': [1, 3, 5], 'Cr(5)': [2, 4, 6], 'Cr(6)': [1, 3, 5], 'Cr(7)': [2, 4, 6],
-                     'Cu(0)': [2, 4, 6], 'Cu(1)': [1, 3, 5], 'Cu(2)': [2, 4, 6], 'Cu(3)': [1, 3, 5], 'Cu(4)': [2, 4, 6],
-                     'Fe(0)': [1, 3, 5], 'Fe(1)': [2, 4, 6], 'Fe(2)': [1, 3, 5], 'Fe(3)': [2, 4, 6], 'Fe(4)': [1, 3, 5], 'Fe(5)': [2, 4, 6],
-                     'Hf(0)': [1, 3], 'Hf(2)': [1, 3], 'Hf(3)': [2, 4], 'Hf(4)': [1, 3],
-                     'Hg(1)': [2, 4], 'Hg(2)': [1, 3],
-                     'Ir(0)': [2, 4], 'Ir(1)': [1, 3], 'Ir(2)': [2, 4], 'Ir(3)': [1, 3], 'Ir(4)': [2, 4], 'Ir(5)': [1, 3],
-                     'Mn(0)': [2, 4, 6], 'Mn(1)': [1, 3, 5], 'Mn(2)': [2, 4, 6], 'Mn(3)': [1, 3, 5], 'Mn(4)': [2, 4, 6], 'Mn(5)': [1, 3, 5], 'Mn(6)': [2, 4, 6], 'Mn(7)': [1, 3, 5],
-                     'Mo(0)': [1, 3], 'Mo(1)': [2, 4], 'Mo(2)': [1, 3], 'Mo(3)': [2, 4], 'Mo(4)': [1, 3], 'Mo(5)': [2, 4], 'Mo(6)': [1, 3], 'Mo(7)': [2, 4],
-                     'Nb(0)': [2, 4], 'Nb(1)': [1, 3], 'Nb(2)': [2, 4], 'Nb(3)': [1, 3], 'Nb(4)': [2, 4], 'Nb(5)': [1, 3],
-                     'Ni(0)': [1, 3, 5],'Ni(1)': [2, 4, 6], 'Ni(2)': [1, 3, 5], 'Ni(3)': [2, 4, 6], 'Ni(4)': [1, 3, 5], 'Ni(5)': [2, 4, 6],
-                     'Os(0)': [1, 3], 'Os(1)': [2, 4], 'Os(2)': [1, 3], 'Os(3)': [2, 4], 'Os(4)': [1, 3], 'Os(5)': [2, 4], 'Os(6)': [1, 3], 'Os(8)': [1, 3],
-                     'Pd(0)': [1, 3], 'Pd(1)': [2, 4], 'Pd(2)': [1, 3], 'Pd(3)': [2, 4], 'Pd(4)': [1, 3],
-                     'Pt(0)': [1, 3], 'Pt(1)': [2, 4], 'Pt(2)': [1, 3], 'Pt(3)': [2, 4], 'Pt(4)': [1, 3], 'Pt(6)': [1, 3],
-                     'Re(0)': [2, 4], 'Re(1)': [1, 3], 'Re(2)': [2, 4], 'Re(3)': [1, 3], 'Re(4)': [2, 4], 'Re(5)': [1, 3], 'Re(6)': [2, 4], 'Re(7)': [1, 3],
-                     'Rh(0)': [2, 4], 'Rh(1)': [1, 3], 'Rh(2)': [2, 4], 'Rh(3)': [1, 3], 'Rh(4)': [2, 4], 'Rh(5)': [1, 3],
-                     'Ru(0)': [1, 3], 'Ru(1)': [2, 4], 'Ru(2)': [1, 3], 'Ru(3)': [2, 4], 'Ru(4)': [1, 3], 'Ru(5)': [2, 4], 'Ru(6)': [1, 3],
-                     'Sc(1)': [1, 3, 5], 'Sc(2)': [2, 4, 6], 'Sc(3)': [1, 3, 5],
-                     'Ta(0)': [2, 4], 'Ta(1)': [1, 3], 'Ta(2)': [2, 4], 'Ta(3)': [1, 3], 'Ta(4)': [2, 4], 'Ta(5)': [1, 3],
-                     'Tc(0)': [2, 4], 'Tc(1)': [1, 3], 'Tc(2)': [2, 4],'Tc(3)': [1, 3], 'Tc(4)': [2, 4], 'Tc(5)': [1, 3], 'Tc(6)': [2, 4], 'Tc(7)': [1, 3],
-                     'Ti(0)': [1, 3, 5], 'Ti(1)': [2, 4, 6], 'Ti(2)': [1, 3, 5], 'Ti(3)': [2, 4, 6], 'Ti(4)': [1, 3, 5],
-                     'V(0)': [2, 4, 6], 'V(1)': [1, 3, 5], 'V(2)': [2, 4, 6], 'V(3)': [1, 3, 5], 'V(4)': [2, 4, 6], 'V(5)': [1, 3, 5],
-                     'W(0)': [1, 3], 'W(1)': [2, 4], 'W(2)': [1, 3], 'W(3)': [2, 4], 'W(4)': [1, 3], 'W(5)': [2, 4], 'W(6)': [1, 3],
-                     'Y(2)': [2, 4], 'Y(3)': [1, 3], 'Y(4)': [2, 4], 'Y(5)': [1, 3],
-                     'Zn(0)': [1, 3, 5], 'Zn(1)': [2, 4, 6], 'Zn(2)': [1, 3, 5], 'Zn(4)': [1, 3, 5], 'Zn(5)': [2, 4, 6],
-                     'Zr(0)': [1, 3], 'Zr(2)': [1, 3], 'Zr(3)': [2, 4], 'Zr(4)': [1, 3]}
-
 # Elements sorted by atomic number
 elementsbynum = ['H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne',
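For reference, the surviving defaultspins table is keyed by d-electron count and returns a spin multiplicity as a string. A minimal sketch of the intended lookup (the d-count for Fe(II) is worked out by hand here for illustration; in practice molSimplify's own d-count bookkeeping, shown in the context lines above, is assumed to supply it):

from molSimplify.Classes.globalvars import defaultspins

# Fe(II): neutral Fe contributes 8 d/s valence electrons in the bookkeeping
# above, so a +2 oxidation state leaves a d6 configuration.
d_count = 6
default_multiplicity = int(defaultspins[d_count])  # -> 5, the high-spin quintet default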
diff --git a/molSimplify/Scripts/enhanced_structgen.py b/molSimplify/Scripts/enhanced_structgen.py
index bb65b468..4071b62e 100644
--- a/molSimplify/Scripts/enhanced_structgen.py
+++ b/molSimplify/Scripts/enhanced_structgen.py
@@ -20,6 +20,7 @@
 from molSimplify.utils.openbabel_helpers import *
 from molSimplify.Scripts.io import lig_load
 from molSimplify.Classes.globalvars import vdwrad
+from molSimplify.Classes.globalvars import globalvars
 from molSimplify.Scripts.io import getlicores, lig_load_safe, parse_bracketed_list
 from typing import Any, List, Dict, Tuple, Union, Optional

@@ -206,6 +207,10 @@ def generate_complex(
     # Accumulate ALL ligands' haptic groups (GLOBAL indices) so we can re-apply anytime
     all_haptic_groups_global = []
+    batslist = []
+    # NEW: collect (backbone_site_1based, core_atom_index_1based) for every filled site
+    backbone_core_pairs = []  # NEW
+
     iteration = 0
     for ligand in ligand_list:
         donor_indices = list(ligand[1])  # may include lists (haptics) or ints
@@ -235,6 +240,7 @@
         denticity = len(donor_groups)
         structure = get_next_structure(metals_structures_copy, denticity)
+
         if manual and iteration < len(manual_list):
             valid_subsets = [manual_list[iteration]]
         else:
@@ -244,6 +250,7 @@
             print(f"Valid subsets: {valid_subsets}")
             print(structure)

+        prev = structure['occupied_mask'].copy()
         # group-aware Kabsch
         best_subset, best_aligned_coords, best_rmsd, placement_attempts, best_perm_idx = clash_aware_kabsch(
             ligand_all_coords,
@@ -276,6 +283,13 @@
         assert best_subset != None, "Check total number of coordination sites in tested geometry, as well as total number of coordinating atoms"
         structure['occupied_mask'][np.array(best_subset)] = True
+        curr = np.array(structure['occupied_mask'].copy())
+        # Find where a value flipped from False → True
+        new_true = np.where((~prev) & curr)[0] + 1  # +1 for 1-based indexing
+        batslist.append(list(new_true))
+        # Update stored state
+        prev = curr.copy()
+
         # group→site mapping (not strictly needed for bonding)
         donor_to_site, site_indices_in_group_order = map_donors_to_sites_haptic_aware(
             donor_groups=donor_groups,
@@ -330,6 +344,14 @@
         )
         # ---------------------------------------------------------------------------------

+        # NEW: record which core3D atom now occupies each backbone site filled by this ligand.
+        # Use the representative donor atom for each donor group (donor_reps).
+        # Both backbone site indices and core atom indices are reported as 1-based.
+        for site_idx_0b, rep_local_0b in zip(site_indices_in_group_order, donor_reps):  # both 0-based
+            core_idx_1b = int(local2global[rep_local_0b]) + 1  # assume 0-based mapping → make 1-based
+            backbone_core_pairs.append((int(site_idx_0b) + 1, core_idx_1b))  # (backbone_site_1b, core_atom_1b)
+        # END NEW
+
         # clean bonds & optimize
         core3D.convert2OBMol(force_clean=True)
         replace_bonds(core3D.OBMol, core3D.bo_dict)
@@ -423,8 +445,14 @@
     # get sterics report
     if run_sterics:
         clashes, severity, fig = run_sterics_check(core3D, max_steps, ff_name)
+    else:
+        fig = None  # just to be explicit
+
+    # NEW: produce the final ordered list of core indices per filled backbone site
+    backbone_core_indices = [core for (site, core) in sorted(backbone_core_pairs, key=lambda x: x[0])]  # NEW

-    return core3D, clashes, severity, fig
+    # OLD + NEW: append batslist and backbone_core_indices at the end
+    return core3D, clashes, severity, fig, batslist, backbone_core_indices

 def run_sterics_check(core3D, max_steps, ff_name):
     optimized_coords, per_atom_ff_force = constrained_forcefield_optimization(
@@ -691,3 +719,174 @@
     print(f"[ok] Wrote complex to {out_path} (format={out_fmt})")
 except Exception as e:
     print(f"[warn] Could not write file ({e}). Returning object only.")
+
+
+def enhanced_init_ANN(metal, ox, spin, ligands, occs, dents,
+                      batslist, tcats, licores, geometry):
+    """Initializes ANN.
+
+    Parameters
+    ----------
+    metal : str
+        Metal symbol.
+    ox : int
+        Oxidation state of the metal.
+    spin : int
+        Spin multiplicity of the complex.
+    ligands : list
+        List of ligands, given as names.
+    occs : list
+        List of ligand occupations (frequencies of each ligand).
+    dents : list
+        List of ligand denticities.
+    batslist : list
+        List of backbone points.
+    tcats : list
+        List of SMILES ligand connecting atoms.
+    licores : dict
+        Ligand dictionary within molSimplify.
+    geometry : str
+        Coordination geometry of the complex.
+
+    Returns
+    -------
+    ANN_flag : bool
+        Whether an ANN call was successful.
+    ANN_bondl : float
+        ANN-predicted bond length.
+    ANN_reason : str
+        Reason for ANN failure, if failed.
+    ANN_attributes : dict
+        Dictionary of predicted attributes of the complex.
+    catalysis_flag : bool
+        Whether or not the complex is compatible with catalytic ANNs.
+
+    """
+    # initialize ANN
+    globs = globalvars()
+    catalysis_flag = False
+
+    # new RACs-ANN
+    from molSimplify.Scripts.tf_nn_prep import tf_ANN_preproc
+    # Set default value [] in case decoration is not used
+    decoration_index = []
+
+    ANN_flag, ANN_reason, ANN_attributes, catalysis_flag = tf_ANN_preproc(
+        metal, ox, spin, ligands, occs, dents, batslist,
+        tcats, licores, False, decoration_index, False,
+        geometry, debug=False)
+
+    if ANN_flag:
+        ANN_bondl = ANN_attributes['ANN_bondl']
+    else:
+        # there needs to be 1 length per possible lig
+        ANN_bondl = len(
+            [item for items in batslist for item in items])*[False]
+    return ANN_flag, ANN_bondl, ANN_reason, ANN_attributes, catalysis_flag
+
+
+import numpy as np
+
+
+def enforce_metal_ligand_distances_and_optimize(
+    core3D,
+    bondl,
+    backbone_core_indices,   # 1-based indices
+    *,
+    ff_name: str = "UFF",
+    max_steps: int = 500,
+    constrain: bool = True,
+    tolerate_zero_vec: float = 1e-6,
+):
+    """
+    Adjust metal–ligand distances and perform a constrained FF optimization.
+    Assumes backbone_core_indices are 1-based. If bondl is None, per-site targets
+    are set to (metal covalent radius + donor covalent radius) in Å.
+    """
+
+    # --- helper: robust scalar cast (accepts scalar or length-1 array/list) ---
+    def _as_float(x):
+        if isinstance(x, (list, tuple, np.ndarray)):
+            arr = np.asarray(x, dtype=float).reshape(-1)
+            if arr.size != 1:
+                raise TypeError(f"Each bond length must be a scalar; got shape {arr.shape} with values {arr}.")
+            return float(arr[0])
+        return float(x)
+
+    # 1) Convert to 0-based indices for internal use
+    donor_idxs = [int(i) - 1 for i in backbone_core_indices]
+
+    # 2) Coords & elements
+    coords, elements = get_all_coords_and_elements(core3D)
+    coords = np.asarray(coords, dtype=float)
+    N = len(coords)
+
+    # 3) Identify metal centers
+    try:
+        metal_indices = core3D.findMetal(transition_metals_only=True)
+    except Exception:
+        tm_set = {
+            "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn",
+            "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd",
+            "La", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg"
+        }
+        metal_indices = [i for i, sym in enumerate(elements) if sym in tm_set]
+
+    if not metal_indices:
+        raise RuntimeError("No metal centers found in core3D.")
+
+    metal_indices = sorted(set(int(i) for i in metal_indices))
+    metal_coords = coords[metal_indices]
+
+    # 4) Quick bond lookup (0-based pairs)
+    bonded_pairs = {(min(a, b), max(a, b)) for (a, b) in core3D.bo_dict}
+
+    def donor_metal_for(idx):
+        """Prefer the metal it's bonded to; else the nearest metal geometrically."""
+        for m in metal_indices:
+            if (min(idx, m), max(idx, m)) in bonded_pairs:
+                return m
+        diffs = metal_coords - coords[idx]
+        return metal_indices[int(np.argmin(np.sum(diffs**2, axis=1)))]
+
+    # --- NEW: if bondl is None, build it from covalent radii sums ---
+    if bondl is None:
+        bondl = []
+        for donor_idx in donor_idxs:
+            if donor_idx < 0 or donor_idx >= N:
+                raise IndexError(f"Donor index {donor_idx+1} out of bounds (1-based input).")
+            m_idx = donor_metal_for(donor_idx)
+            r_m = float(core3D.getAtom(m_idx).rad)      # Å
+            r_l = float(core3D.getAtom(donor_idx).rad)  # Å
+            bondl.append(r_m + r_l)
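Taken together, the new return values thread through like this; a minimal usage sketch (most generate_complex keyword arguments are elided and assumed to keep their defaults, and ligand_list construction via create_ligand_list is not shown):

mol, clash, severity, fig, batslist, backbone_core_indices = generate_complex(
    ligand_list,
    metals="Fe",
    geometry="octahedral",
)
# batslist: 1-based backbone sites newly occupied by each ligand, in placement order;
# backbone_core_indices: the 1-based core3D atom index sitting at each filled site.

# Passing bondl=None makes each target distance default to the sum of the metal
# and donor covalent radii before the constrained force-field cleanup.
mol = enforce_metal_ligand_distances_and_optimize(
    mol, None, backbone_core_indices, ff_name="UFF", max_steps=500)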
+    else:
+        if len(bondl) != len(donor_idxs):
+            raise ValueError("bondl and backbone_core_indices must have the same length")
+
+    # 5) Move each donor along its M–L line to the target distance
+    new_coords = coords.copy()
+    for target_dist_raw, donor_idx in zip(bondl, donor_idxs):
+        if donor_idx < 0 or donor_idx >= N:
+            raise IndexError(f"Donor index {donor_idx+1} out of bounds (1-based input).")
+
+        target_dist = _as_float(target_dist_raw)
+
+        m_idx = donor_metal_for(donor_idx)
+        M = coords[m_idx]
+        L = coords[donor_idx]
+
+        vec = L - M
+        dist = float(np.linalg.norm(vec))
+        if dist < tolerate_zero_vec:
+            direction = np.array([0.0, 0.0, 1.0], dtype=float)
+        else:
+            direction = vec / dist
+
+        new_coords[donor_idx] = M + direction * target_dist
+
+    core3D = set_new_coords(core3D, new_coords)
+
+    # 6) Constrained FF optimization (freeze metals + donors)
+    if constrain:
+        frozen = sorted(set(metal_indices + donor_idxs))
+        optimized = constrained_forcefield_optimization(
+            core3D,
+            frozen,
+            max_steps=max_steps,
+            ff_name=ff_name,
+        )
+        core3D = set_new_coords(core3D, optimized)
+
+    return core3D
diff --git a/molSimplify/Scripts/tf_nn_prep.py b/molSimplify/Scripts/tf_nn_prep.py
index 0586b25c..27364e62 100644
--- a/molSimplify/Scripts/tf_nn_prep.py
+++ b/molSimplify/Scripts/tf_nn_prep.py
@@ -713,10 +713,18 @@
         print(
             'warning, ANN predicts a near degenerate ground state for this complex')
     print(f"ANN predicts a spin splitting (HS - LS) of {float(split[0]):.2f} kcal/mol at {100 * alpha:.0f}% HFX")
-    print('ANN low spin bond length (ax1/ax2/eq) is predicted to be: ' + " /".join(
-        [f"{float(i):.2f}" for i in r_ls[0]]) + ' angstrom')
-    print('ANN high spin bond length (ax1/ax2/eq) is predicted to be: ' + " /".join(
-        [f"{float(i):.2f}" for i in r_hs[0]]) + ' angstrom')
+
+    try:
+        print('ANN low spin bond length (ax1/ax2/eq) is predicted to be: ' + " /".join(
+            [f"{float(i):.2f}" for i in r_ls[0]]) + ' angstrom')
+        print('ANN high spin bond length (ax1/ax2/eq) is predicted to be: ' + " /".join(
+            [f"{float(i):.2f}" for i in r_hs[0]]) + ' angstrom')
+    except Exception:
+        vals = _flat_floats(r_ls)
+        print('ANN low spin bond length (ax1/ax2/eq) is predicted to be: ' + ', '.join(f"{v:.2f}" for v in vals) + ' angstrom')
+        vals = _flat_floats(r_hs)
+        print('ANN high spin bond length (ax1/ax2/eq) is predicted to be: ' + ', '.join(f"{v:.2f}" for v in vals) + ' angstrom')
+
     print(f'distance to splitting energy training data is {split_dist:.2f}')
     print(splitting_ANN_trust_message)
     print()
@@ -734,6 +742,14 @@
     return ANN_attributes


+def _flat_floats(x):
+    import numpy as np
+    arr = np.asarray(x, dtype=float)
+    if arr.ndim > 1:
+        arr = arr[0]
+    return arr.ravel()
+
+
 def evaluate_catalytic_anns(this_complex: mol3D, metal: str, ox: int, spin: int,
                             custom_ligand_dict: Dict[str, list], net_lig_charge: int,
                             exchange: Union[str, float, int],
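The fallback print path above leans on _flat_floats accepting either a flat vector or a (1, n) batch; a quick self-contained illustration (the values are hypothetical):

import numpy as np

def _flat_floats(x):
    arr = np.asarray(x, dtype=float)
    if arr.ndim > 1:
        arr = arr[0]          # drop the batch dimension
    return arr.ravel()        # always a flat float array

print(_flat_floats([[1.99, 2.01, 2.05]]))    # -> [1.99 2.01 2.05]
print(_flat_floats(np.array([1.99, 2.01])))  # -> [1.99 2.01]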
diff --git a/molSimplify/__main__.py b/molSimplify/__main__.py
index 8e8ddc51..21c44e4b 100644
--- a/molSimplify/__main__.py
+++ b/molSimplify/__main__.py
@@ -26,7 +26,9 @@
 # fix OB bug: https://github.com/openbabel/openbabel/issues/1983
 import sys
 import os
+import numpy as np
 import argparse
+from molSimplify.Scripts.io import getlicores
 if not ('win' in sys.platform):
     flags = sys.getdlopenflags()
 if not ('win' in sys.platform):
@@ -40,7 +42,7 @@
                           parseinputs_ligdict, parseinputs_basic, parseCLI)
 from molSimplify.Scripts.generator import startgen
-from molSimplify.Classes.globalvars import globalvars
+from molSimplify.Classes.globalvars import globalvars, geometry_vectors
 from molSimplify.utils.tensorflow import tensorflow_silence

@@ -129,6 +131,8 @@ def main(args=None):
     from molSimplify.Scripts.enhanced_structgen import (
         create_ligand_list,
         generate_complex,
+        enhanced_init_ANN,
+        enforce_metal_ligand_distances_and_optimize
     )
     from molSimplify.Scripts.enhanced_structgen_functionality import check_badjob

@@ -177,6 +181,8 @@ def _parse_usercatoms(s: str):
     # Common build knobs (pass-through to generate_complex)
     parser.add_argument("--metal", default="Fe")
+    parser.add_argument("--ox", type=int, default=2)
+    parser.add_argument("--spin", type=int, default=1)
     parser.add_argument("--geometry", default="octahedral")
     parser.add_argument("--voxel-size", type=float, default=0.5)
     parser.add_argument("--vdw-scale", type=float, default=0.8)
@@ -185,6 +191,7 @@
     parser.add_argument("--max-steps", type=int, default=500)
     parser.add_argument("--ff-name", default="UFF")
    parser.add_argument("--verbose", action="store_true")
+    parser.add_argument("--ANN", action="store_true")
     parser.add_argument("--smart-generation", action="store_true", default=True)
     parser.add_argument("--no-smart-generation", dest="smart_generation",
                         action="store_false")
@@ -223,6 +230,7 @@
     pargs = parser.parse_args(subargv)

+    # Normalize lists: usercatoms/occupancies/isomers can be missing; create_ligand_list handles None
     ligands = pargs.ligands

     if pargs.usercatoms is not None:
@@ -281,7 +289,7 @@
         vis_view = (22, -60)

     # Run build
-    mol, clash, severity, fig = generate_complex(
+    mol, clash, severity, fig, batslist, backbone_core_indices = generate_complex(
         ligand_list,
         metals=pargs.metal,
         voxel_size=pargs.voxel_size,
@@ -310,6 +318,40 @@
         run_sterics=pargs.run_sterics,
     )

+    dents = []
+    for lig in ligand_list:
+        dents.append(len(lig[1]))
+
+    # -------------------- ANN ------------------------
+    ANN_flag = False
+    ANN_bondl = None
+    if pargs.ANN:
+        metal = pargs.metal
+        ox = pargs.ox
+        spin = pargs.spin
+        ligands = pargs.ligands
+        occs = pargs.occupancies
+        dents = dents
+        batslist = batslist
+        tcats = [[], [], [], [], [], []]
+        licores = getlicores()
+        geometry = pargs.geometry
+
+        try:
+            ANN_flag, ANN_bondl, ANN_reason, ANN_attributes, catalysis_flag = enhanced_init_ANN(metal, ox, spin, ligands, occs, dents,
+                                                                                                batslist, tcats, licores, geometry)
+        except Exception:
+            print("ANN failed. Skipping...")

+    # -------------------- Metal-Ligand Bond Distance ------------------------
+    bondl = None
+    if ANN_flag and ANN_bondl is not None:
+        bondl = []
+        for length in ANN_bondl:
+            bondl.append(length[1])
+
+    mol = enforce_metal_ligand_distances_and_optimize(mol, bondl, backbone_core_indices)
+
     # -------------------- auto-build run name --------------------
     import re, hashlib
diff --git a/molSimplify/python_nn/tf_ANN.py b/molSimplify/python_nn/tf_ANN.py
index 3203fbd9..c391da29 100644
--- a/molSimplify/python_nn/tf_ANN.py
+++ b/molSimplify/python_nn/tf_ANN.py
@@ -18,13 +18,127 @@
 import scipy
 from typing import List, Tuple, Union, Optional
 from tensorflow.keras import backend as K
-from tensorflow.keras.models import model_from_json, load_model
 from importlib_resources import files as resource_files
 from packaging import version
 import tensorflow as tf
+import sys, types
 from molSimplify.python_nn.clf_analysis_tool import array_stack, get_layer_outputs, dist_neighbor, get_entropy
+import os
+from pathlib import Path
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+
+# --- Keras compat (no mypy redefs) ---
+try:
+    from keras.models import load_model as _km_load_model, model_from_json as _km_model_from_json
+    from keras.optimizers import Adam as _km_Adam
+except Exception:
+    from tensorflow.keras.models import load_model as _km_load_model, model_from_json as _km_model_from_json  # type: ignore[no-redef]
+    from tensorflow.keras.optimizers import Adam as _km_Adam  # type: ignore[no-redef]
+
+# Single public names used everywhere below (defined once)
+load_model = _km_load_model  # type: ignore[assignment]
+model_from_json = _km_model_from_json  # type: ignore[assignment]
+Adam = _km_Adam  # type: ignore[assignment]
+
+# --- importlib.resources compat ---
+try:
+    from importlib.resources import files as resource_files  # type: ignore[attr-defined]
+except Exception:
+    try:
+        from importlib_resources import files as resource_files  # type: ignore[no-redef]
+    except Exception:
+        resource_files = None  # type: ignore[assignment]
+
+
+def get_layer_outputs2(model, layer_index, input_array, training_flag=False):
+    """
+    Return the output of model.layers[layer_index] for a single example.
+    Ensures input has correct dtype/shape and matches the model's expected structure.
+    """
+    # 1) Build a partial model from the original graph
+    try:
+        # Functional/Sequential both support these
+        partial_model = tf.keras.Model(inputs=model.inputs,
+                                       outputs=model.layers[layer_index].output)
+    except Exception:
+        # Fallback for very old Sequential APIs
+        partial_model = tf.keras.Model(inputs=model.input,
+                                       outputs=model.layers[layer_index].output)
+
+    # 2) Prepare inputs (ndarray float32, correct rank)
+    x = _prepare_inputs_for_model(model, input_array)
+
+    # 3) Call with the correct structure (no extra list wrappers, no strings)
+    out = partial_model(x, training=training_flag)
+    # Ensure eager value
+    return out.numpy() if hasattr(out, "numpy") else np.array(out)
+
+
+def _prepare_inputs_for_model(model, X):
+    """Ensure numeric ndarray and correct rank for model.input(s)."""
+    X = np.asarray(X, dtype=np.float32)
+    X = _fix_shape_for_model(model, X)  # helper defined below: adds a (time) axis if needed
+    return X
+
+
+def _install_sklearn_compat_shims():
+    """
+    Create backwards-compat import paths so old pickles load under modern sklearn.
+    Covers:
+      - sklearn.svm.classes -> sklearn.svm._classes
+      - sklearn.externals.joblib -> joblib
+      - sklearn.grid_search -> sklearn.model_selection
+      - sklearn.cross_validation -> sklearn.model_selection
+      - sklearn.utils.fixes (an empty placeholder module) if referenced
+    """
+    # 1) svm.classes -> svm._classes
+    try:
+        import sklearn.svm._classes as _svm_new
+        mod_name = "sklearn.svm.classes"
+        if mod_name not in sys.modules:
+            shim = types.ModuleType(mod_name)
+            for name in ("SVC", "NuSVC", "SVR", "NuSVR", "OneClassSVM", "LinearSVC", "LinearSVR"):
+                if hasattr(_svm_new, name):
+                    setattr(shim, name, getattr(_svm_new, name))
+            sys.modules[mod_name] = shim
+    except Exception:
+        pass
+
+    # 2) externals.joblib -> joblib
+    try:
+        import joblib as _jl
+        mod_name = "sklearn.externals.joblib"
+        if mod_name not in sys.modules:
+            sys.modules[mod_name] = _jl
+    except Exception:
+        pass
+
+    # 3) grid_search -> model_selection
+    try:
+        import sklearn.model_selection as _ms
+        mod_name = "sklearn.grid_search"
+        if mod_name not in sys.modules:
+            sys.modules[mod_name] = _ms
+    except Exception:
+        pass
+
+    # 4) cross_validation -> model_selection
+    try:
+        import sklearn.model_selection as _ms2
+        mod_name = "sklearn.cross_validation"
+        if mod_name not in sys.modules:
+            sys.modules[mod_name] = _ms2
+    except Exception:
+        pass
+
+    # 5) utils.fixes is sometimes referenced; provide a harmless placeholder
+    mod_name = "sklearn.utils.fixes"
+    if mod_name not in sys.modules:
+        sys.modules[mod_name] = types.ModuleType(mod_name)
+
 def perform_ANN_prediction(RAC_dataframe: pd.DataFrame, predictor_name: str,
                            RAC_column: str = 'RACs') -> pd.DataFrame:
@@ -392,6 +506,7 @@ def load_keras_ann(predictor: str, suffix: str = 'model', compile: bool = False)
     # this function loads the ANN for property
     # "predictor"
     # disable TF output text to reduce console spam
+    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
     key = get_key(predictor, suffix)

     if "clf" not in predictor:
@@ -438,6 +553,229 @@
     return loaded_model


+
+# --- sklearn loader + adapter ---
+def _try_load_sklearn_classifier(predictor: str):
+    """
+    If an sklearn model for the classifier exists under sklearn_models/<predictor>/
+    (or one of the tf_nn/<predictor>/ fallbacks below), load it with joblib and
+    return an adapter that exposes predict/predict_proba.
+    Otherwise return None and let the Keras fallback take over.
+    """
+    import joblib
+
+    here = Path(__file__).resolve()
+    # Allow override via env var if you want (optional)
+    base_override = os.environ.get("MOLSIMPLIFY_MODEL_DIR")  # e.g., /path/to/models
+    candidates = []
+    if base_override:
+        candidates.append(Path(base_override) / "sklearn_models" / predictor)
+    # common relative places
+    candidates += [
+        here.parent / "tf_nn" / predictor,                # molSimplify/python_nn/tf_nn/<predictor>
+        here.parent.parent / "tf_nn" / predictor,         # molSimplify/tf_nn/<predictor>
+        here.parent.parent.parent / "tf_nn" / predictor,  # repo_root/tf_nn/<predictor>
+    ]
+
+    filenames = ["model.joblib", "clf.joblib", "model.pkl", "clf.pkl"]
+    for d in candidates:
+        for fname in filenames:
+            p = d / fname
+            if p.exists():
+                sk = joblib.load(p)
+
+                class _SKAdapter:
+                    def __init__(self, skobj, name):
+                        self._sk = skobj
+                        self.name = name
+
+                    # mimic the Keras .predict signature loosely
+                    def predict(self, X, batch_size=None, verbose=None):
+                        return self._sk.predict(X)
+
+                    def predict_proba(self, X):
+                        if hasattr(self._sk, "predict_proba"):
+                            return self._sk.predict_proba(X)
+                        raise AttributeError("This sklearn model has no predict_proba")
+
+                    # handy for debugging/logs
+                    def __repr__(self):
+                        return f"<_SKAdapter for {self.name}>"
+
+                return _SKAdapter(sk, predictor)
+    return None
+
+
+def _fix_shape_for_model(model, X: np.ndarray) -> np.ndarray:
+    """
+    Ensure X matches model.input_shape:
+      - If model expects (None, None, F) and X is (N, F) -> make (N, 1, F)
+      - If model expects (None, F) and X is (N, 1, F) -> squeeze to (N, F)
+    Returns a float32 array.
+    """
+    inp_shape = getattr(model, "input_shape", None)
+    if isinstance(inp_shape, list):  # multi-input models -> use first
+        inp_shape = inp_shape[0]
+
+    if inp_shape is not None:
+        # Normalize negative/None dims away from the batch dimension
+        if len(inp_shape) == 3:
+            # (batch, timesteps, features)
+            if X.ndim == 2:
+                X = np.expand_dims(X, axis=1)  # (N, 1, F)
+        elif len(inp_shape) == 2:
+            # (batch, features)
+            if X.ndim == 3 and X.shape[1] == 1:
+                X = X[:, 0, :]  # (N, F)
+    else:
+        # If the model has no input_shape attr, fall back to "add a time axis if 2D"
+        if X.ndim == 2:
+            X = np.expand_dims(X, axis=1)
+    return X.astype(np.float32, copy=False)
+
+
+def load_keras_ann2(predictor: str, suffix: str = "model", compile: bool = False):
+    """
+    Loads legacy Keras (JSON+H5) or monolithic H5 models from anywhere under molSimplify/tf_nn/.
+    Searches recursively with glob so differing subdirectory layouts work automatically.
+    """
+    import os
+    from pathlib import Path
+    import glob
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+    # --- Keras imports ---
+    try:
+        try:
+            from keras.saving.legacy.model_config import model_from_json as _legacy_model_from_json
+        except Exception:
+            _legacy_model_from_json = None
+    except Exception:
+        _legacy_model_from_json = None
+
+    here = Path(__file__).resolve()
+    tf_nn_root = here.parent.parent / "tf_nn"
+
+    def _adam_compat(**kwargs):
+        lr = kwargs.pop("lr", None)
+        kwargs.pop("decay", None)
+        if lr is not None:
+            kwargs["learning_rate"] = lr
+        return Adam(**kwargs)
+
+    def _glob_find(basename):
+        """
+        Recursively search tf_nn/ for basename (case-sensitive).
+        Returns the first match as a Path or None.
+        """
+        pattern = str(tf_nn_root / "**" / basename)
+        hits = glob.glob(pattern, recursive=True)
+        if hits:
+            return Path(hits[0])
+        return None
+
+    def _model_from_json_compat(model_json: str):
+        """
+        Minimal, mypy-safe loader for legacy Keras JSONs.
+          - Prefer keras.saving.legacy when available
+          - Otherwise fall back to keras / tf.keras with a small custom_objects dict
+        """
+        # 1) Try the legacy loader (works without custom_objects)
+        try:
+            from keras.saving.legacy.model_config import model_from_json as _km_legacy_mfj  # type: ignore[attr-defined]
+            return _km_legacy_mfj(model_json)
+        except Exception:
+            pass
+
+        # 2) Try keras.* with explicit custom_objects
+        try:
+            from keras.models import model_from_json as _km_mfj, Sequential as _km_Sequential
+            from keras.layers import Dense as _km_Dense, Dropout as _km_Dropout, BatchNormalization as _km_BN, Activation as _km_Act
+            from keras.initializers import VarianceScaling as _km_VS, Zeros as _km_Z, Ones as _km_O
+            from keras.regularizers import L1L2 as _km_L1L2
+            _custom_objects = {
+                "Sequential": _km_Sequential,
+                "Dense": _km_Dense,
+                "Dropout": _km_Dropout,
+                "BatchNormalization": _km_BN,
+                "Activation": _km_Act,
+                "VarianceScaling": _km_VS,
+                "Zeros": _km_Z,
+                "Ones": _km_O,
+                "L1L2": _km_L1L2,
+            }
+            return _km_mfj(model_json, custom_objects=_custom_objects)
+        except Exception:
+            pass
+
+        # 3) Fall back to tf.keras.* with explicit custom_objects
+        from tensorflow.keras.models import model_from_json as _tk_mfj, Sequential as _tk_Sequential
+        from tensorflow.keras.layers import Dense as _tk_Dense, Dropout as _tk_Dropout, BatchNormalization as _tk_BN, Activation as _tk_Act
+        from tensorflow.keras.initializers import VarianceScaling as _tk_VS, Zeros as _tk_Z, Ones as _tk_O
+        from tensorflow.keras.regularizers import L1L2 as _tk_L1L2
+        _custom_objects = {
+            "Sequential": _tk_Sequential,
+            "Dense": _tk_Dense,
+            "Dropout": _tk_Dropout,
+            "BatchNormalization": _tk_BN,
+            "Activation": _tk_Act,
+            "VarianceScaling": _tk_VS,
+            "Zeros": _tk_Z,
+            "Ones": _tk_O,
+            "L1L2": _tk_L1L2,
+        }
+        return _tk_mfj(model_json, custom_objects=_custom_objects)
+
+    # -------- actual loading --------
+    if "clf" not in predictor:
+        json_name = f"{predictor}_{suffix}.json"
+        h5_name = f"{predictor}_{suffix}.h5"
+
+        json_path = _glob_find(json_name)
+        h5_path = _glob_find(h5_name)
+        if json_path is None:
+            raise FileNotFoundError(f"{json_name}")
+        if h5_path is None:
+            raise FileNotFoundError(f"{h5_name}")
+
+        model_json = json_path.read_text()
+        model = _model_from_json_compat(model_json)
+        model.load_weights(h5_path)
+    else:
+        # classifier = single .h5 file
+        h5_name = f"{predictor}_{suffix}.h5"
+        h5_path = _glob_find(h5_name)
+        if h5_path is None:
+            raise FileNotFoundError(h5_name)
+        model = load_model(h5_path, compile=False)
+
+    # -------- optional compile --------
+    if compile:
+        if predictor == "homo":
+            opt = _adam_compat(beta_2=1 - 0.0016204733101599046,
+                               beta_1=0.8718839135783554,
+                               lr=0.0004961686075897741)
+            model.compile(loss="mse", optimizer=opt, metrics=["mse", "mae", "mape"])
+        elif predictor == "gap":
+            opt = _adam_compat(beta_2=1 - 0.00010929248596488832,
+                               beta_1=0.8406735969305784,
+                               lr=0.0006759924688701965)
+            model.compile(loss="mse", optimizer=opt, metrics=["mse", "mae", "mape"])
+        elif predictor in ["oxo", "hat", "oxo20"]:
+            opt = _adam_compat(lr=0.0012838133056087084,
+                               beta_1=0.9811686522122317,
+                               beta_2=0.8264616523572279)
+            model.compile(loss="mse", optimizer=opt, metrics=["mse", "mae", "mape"])
+        elif predictor == "homo_empty":
+            opt = _adam_compat(lr=0.006677578283098809,
+                               beta_1=0.8556594887870226,
+                               beta_2=0.9463468021275508)
+            model.compile(loss="mse", optimizer=opt, metrics=["mse", "mae", "mape"])
+        elif predictor in ["geo_static_clf", "sc_static_clf"]:
+            opt = _adam_compat(lr=5e-5, beta_1=0.95, amsgrad=True)
+            model.compile(loss="binary_crossentropy", optimizer=opt, metrics=["accuracy"])
+        else:
+            model.compile(loss="mse", optimizer=Adam(), metrics=["mse", "mae", "mape"])
+
+    return model
+
+
 def tf_ANN_excitation_prepare(predictor: str, descriptors: List[float],
                               descriptor_names: List[str]) -> np.ndarray:
     ## this function reforms the provided list of descriptors and their
     ## names to match the expectations of the target ANN model.
@@ -480,8 +818,14 @@
         excitation = data_normalize(excitation, train_mean_x, train_var_x, debug=debug)

     ## fetch ANN
-    loaded_model = load_keras_ann(predictor)
-    result = data_rescale(loaded_model.predict(excitation, verbose=0), train_mean_y, train_var_y, debug=debug)
+    try:
+        loaded_model = load_keras_ann(predictor)
+        result = data_rescale(loaded_model.predict(excitation, verbose=0), train_mean_y, train_var_y, debug=debug)
+    except Exception:
+        loaded_model = load_keras_ann2(predictor)
+        excitation = _fix_shape_for_model(loaded_model, excitation)
+        result = data_rescale(loaded_model.predict(excitation, verbose=0), train_mean_y, train_var_y, debug=debug)
+
     if "clf" not in predictor:
         if debug:
             print(f'LOADED MODEL HAS {len(loaded_model.layers)} layers, so latent space measure will be from first {len(loaded_model.layers) - 1} layers')
@@ -603,7 +947,11 @@ def find_ANN_latent_dist(predictor, latent_space_vector, debug=False):
     min_dist = 100000000
     min_ind = 0

-    loaded_model = load_keras_ann(predictor)
+    try:
+        loaded_model = load_keras_ann(predictor)
+    except Exception:
+        loaded_model = load_keras_ann2(predictor)
+
     if debug:
         print('measuring latent distances:')

@@ -618,8 +966,12 @@
         scaled_row = np.squeeze(
             data_normalize(rows, train_mean_x.T, train_var_x.T, debug=debug))  # Normalizing the row before finding the distance
         if version.parse(tf.__version__) >= version.parse('2.0.0'):
-            latent_train_row = get_layer_outputs(loaded_model, len(loaded_model.layers) - 2,
-                                                 [np.array([scaled_row])], training_flag=False)
+            try:
+                latent_train_row = get_layer_outputs(loaded_model, len(loaded_model.layers) - 2,
+                                                     [np.array([scaled_row])], training_flag=False)
+            except Exception:
+                latent_train_row = get_layer_outputs2(loaded_model, len(loaded_model.layers) - 2,
+                                                      [np.array([scaled_row])], training_flag=False)
         else:
             latent_train_row = get_outputs([np.array([scaled_row]), 0])
         this_dist = np.linalg.norm(np.subtract(np.squeeze(latent_train_row), np.squeeze(latent_space_vector)))
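End-to-end, the compatibility path introduced here is a simple try/fallback chain; a minimal sketch (the predictor name comes from the compile branches above, while the feature count and random input are illustrative only):

import numpy as np

try:
    model = load_keras_ann("homo")    # original loader
except Exception:
    model = load_keras_ann2("homo")   # glob-based legacy loader added above

X = np.random.rand(1, 153)            # 153 features is illustrative only
X = _fix_shape_for_model(model, X)    # add/drop the time axis as the model expects
y = model.predict(X, verbose=0)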