diff --git a/src/deprotecting_group.py b/src/deprotecting_group.py deleted file mode 100644 index 0a09569..0000000 --- a/src/deprotecting_group.py +++ /dev/null @@ -1,151 +0,0 @@ -from rdkit import Chem -import re - -# These are the protecting groups that we want to unmask (reverse of PG_MAP) -DEPROTECT_MAP = { - "$": "OC", # OMe - "%": "COCc1ccccc1", # OBn - "&": "COC", # OEt -} - - -def unmask_protecting_groups_multisymbol(smiles: str) -> str: - """ - Replace single-character symbols in a SMILES string with their original protecting groups. - - This function reverses the masking process by converting simplified single-character - symbols back to their original protecting group SMILES representations. This is - useful for converting masked molecules back to their full chemical representations - after retrosynthetic analysis. - - Parameters - ---------- - smiles : str - Input SMILES string containing masked protecting group symbols. Can be any - valid SMILES notation with symbols '$', '%', '&' representing protecting groups. - - Returns - ------- - str - Modified SMILES string with symbols replaced by original protecting groups: - - - '$' → 'OC' (OMe) - - '%' → 'COCc1ccccc1' (OBn) - - '&' → 'COC' (OEt) - - Returns 'INVALID_SMILES' if the input cannot be parsed by RDKit. - Returns empty string if the input is an empty string. - - Examples - -------- - >>> unmask_protecting_groups_multisymbol("&") - 'COC' - - >>> unmask_protecting_groups_multisymbol("%") - 'COCc1ccccc1' - - >>> unmask_protecting_groups_multisymbol("CC(C)%") - 'CC(C)COCc1ccccc1' - - >>> unmask_protecting_groups_multisymbol("&.&") - 'COC.COC' - - >>> unmask_protecting_groups_multisymbol("CC(C)$%&") - 'CC(C)OCCOCc1ccccc1COC' - - >>> unmask_protecting_groups_multisymbol("invalid_smiles") - 'INVALID_SMILES' - - >>> unmask_protecting_groups_multisymbol("") - '' - - References - ---------- - .. [1] Greene, T.W. and Wuts, P.G.M. (2006) Protective Groups in Organic Synthesis. - 4th Edition, John Wiley & Sons, Inc., Hoboken. - """ - if not smiles: - return "" - - # Handle special case for "INVALID_SMILES" - if smiles == "INVALID_SMILES": - return "INVALID_SMILES" - - # First, replace all symbols with their corresponding protecting groups - unmasked_smiles = smiles - for symbol, pg_smiles in DEPROTECT_MAP.items(): - unmasked_smiles = unmasked_smiles.replace(symbol, pg_smiles) - - # Validate that the resulting SMILES is valid - mol = Chem.MolFromSmiles(unmasked_smiles) - if mol is None: - return "INVALID_SMILES" - - # Return canonical form of the unmasked SMILES - return Chem.MolToSmiles(mol, canonical=True) - - -def get_protecting_group_info() -> dict: - """ - Get information about the protecting groups used in the masking/unmasking process. - - Returns - ------- - dict - Dictionary containing information about each protecting group: - - symbol: The single-character symbol used for masking - - smiles: The SMILES representation of the protecting group - - name: The common name of the protecting group - - description: Brief description of the protecting group - """ - return { - "$": { - "smiles": "OC", - "name": "OMe", - "description": "Methoxy protecting group" - }, - "%": { - "smiles": "COCc1ccccc1", - "name": "OBn", - "description": "Benzyl protecting group" - }, - "&": { - "smiles": "COC", - "name": "OEt", - "description": "Ethoxy protecting group" - } - } - - -def validate_masked_smiles(smiles: str) -> bool: - """ - Validate if a SMILES string contains valid masking symbols. - - Parameters - ---------- - smiles : str - Input SMILES string to validate - - Returns - ------- - bool - True if the SMILES contains valid masking symbols ('$', '%', '&') - and other valid SMILES characters, False otherwise - """ - if not smiles: - return True - - # Check if the SMILES contains any of our masking symbols - has_masking_symbols = any(symbol in smiles - for symbol in DEPROTECT_MAP.keys()) - - if not has_masking_symbols: - return False - - # Define valid characters for masked SMILES (including our symbols) - valid_chars = set("$%&") | set( - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789()[]{}@+-=#$%&*./\\" - ) - - # Check if all characters in the SMILES are valid - return all(char in valid_chars for char in smiles) diff --git a/src/utils/llm.py b/src/utils/llm.py index 793393a..fab1a24 100644 --- a/src/utils/llm.py +++ b/src/utils/llm.py @@ -17,7 +17,6 @@ from src.utils.stability_checks import stability_checker from src.utils.hallucination_checks import hallucination_checker from src.protecting_group import mask_protecting_groups_multisymbol -from src.deprotecting_group import unmask_protecting_groups_multisymbol, get_protecting_group_info load_dotenv() diff --git a/test_deprotecting_group.py b/test_deprotecting_group.py deleted file mode 100644 index 68221ac..0000000 --- a/test_deprotecting_group.py +++ /dev/null @@ -1,317 +0,0 @@ -#!/usr/bin/env python3 -""" -Comprehensive test suite for the deprotecting group module. -""" - -import unittest -from src.deprotecting_group import (unmask_protecting_groups_multisymbol, - get_protecting_group_info, - validate_masked_smiles, DEPROTECT_MAP) - - -class TestDeprotectingGroup(unittest.TestCase): - """Test cases for the deprotecting group module.""" - - def setUp(self): - """Set up test fixtures.""" - self.test_cases = { - # Simple single protecting groups - "$": "CO", # OMe (canonicalized) - "%": "COCc1ccccc1", # OBn - "&": "COC", # OEt - - # Molecules with single protecting groups - "CC(C)$": "COC(C)C", # Canonicalized - "CC(C)%": "CC(C)COCc1ccccc1", - "CC(C)&": "COCC(C)C", # Canonicalized - - # Molecules with multiple protecting groups - "CC(C)$%": "CC(C)OCCOCc1ccccc1", # Not canonicalized in this case - "CC(C)$&": "COCCOC(C)C", # Canonicalized - "CC(C)%&": "COCc1ccccc1COCC(C)C", # Canonicalized - "CC(C)$%&": "COCc1ccccc1COCCOC(C)C", # Canonicalized - - # Complex molecules - "CC(C)(C)$": "COC(C)(C)C", # Canonicalized - "c1ccccc1%": "c1ccc(COCc2ccccc2)cc1", # Canonicalized - "CCO&": "CCOCOC", - - # Molecules with disconnections (salts) - "$.%": "CO.COCc1ccccc1", # Canonicalized - "&.$": "CO.COC", # Canonicalized - "$%&": "COCc1ccccc1COCCO", # Canonicalized - - # Edge cases - "": "", - "INVALID_SMILES": "INVALID_SMILES", - } - - def test_unmask_single_protecting_groups(self): - """Test unmasking of single protecting group symbols.""" - for masked, expected in self.test_cases.items(): - if masked in ["", "INVALID_SMILES"]: - continue - with self.subTest(masked=masked): - result = unmask_protecting_groups_multisymbol(masked) - self.assertEqual(result, expected) - - def test_unmask_empty_string(self): - """Test unmasking of empty string.""" - result = unmask_protecting_groups_multisymbol("") - self.assertEqual(result, "") - - def test_unmask_invalid_smiles(self): - """Test unmasking of 'INVALID_SMILES' string.""" - result = unmask_protecting_groups_multisymbol("INVALID_SMILES") - self.assertEqual(result, "INVALID_SMILES") - - def test_unmask_no_symbols(self): - """Test unmasking of SMILES without protecting group symbols.""" - test_cases = [ - "CC(C)", - "c1ccccc1", - "CCO", - "CC(C)(C)C", - ] - - for smiles in test_cases: - with self.subTest(smiles=smiles): - result = unmask_protecting_groups_multisymbol(smiles) - # Should return canonicalized SMILES since no symbols were present - # The canonicalization might change the SMILES representation - self.assertNotEqual(result, "INVALID_SMILES") - self.assertIsInstance(result, str) - - def test_unmask_invalid_symbols(self): - """Test unmasking of SMILES with invalid symbols.""" - invalid_cases = [ - "CC(C)#", # Invalid symbol - "CC(C)@", # Invalid symbol - "CC(C)!", # Invalid symbol - ] - - for smiles in invalid_cases: - with self.subTest(smiles=smiles): - result = unmask_protecting_groups_multisymbol(smiles) - # Should return INVALID_SMILES since these are invalid SMILES - self.assertEqual(result, "INVALID_SMILES") - - def test_unmask_complex_molecules(self): - """Test unmasking of complex molecules with multiple protecting groups.""" - complex_cases = { - "CC(C)$%&": "COCc1ccccc1COCCOC(C)C", # Canonicalized - "c1ccccc1$%": "c1ccc(COCCOc2ccccc2)cc1", # Canonicalized - "CCO$%&": "CCOOCCOCc1ccccc1COC", # Canonicalized - } - - for masked, expected in complex_cases.items(): - with self.subTest(masked=masked): - result = unmask_protecting_groups_multisymbol(masked) - self.assertEqual(result, expected) - - def test_unmask_disconnected_molecules(self): - """Test unmasking of disconnected molecules (salts).""" - salt_cases = { - "$.%": "CO.COCc1ccccc1", # Canonicalized - "&.$": "CO.COC", # Canonicalized - "%.&": "COC.COCc1ccccc1", # Canonicalized - } - - for masked, expected in salt_cases.items(): - with self.subTest(masked=masked): - result = unmask_protecting_groups_multisymbol(masked) - self.assertEqual(result, expected) - - def test_unmask_symbol_order_independence(self): - """Test that unmasking works regardless of symbol order in DEPROTECT_MAP.""" - # Test that the order of replacement doesn't matter - test_smiles = "CC(C)$%&" - - # Get the expected result - expected = unmask_protecting_groups_multisymbol(test_smiles) - - # Verify that the result is consistent - for _ in range(5): # Test multiple times - result = unmask_protecting_groups_multisymbol(test_smiles) - self.assertEqual(result, expected) - - -class TestProtectingGroupInfo(unittest.TestCase): - """Test cases for the protecting group information functions.""" - - def test_get_protecting_group_info(self): - """Test the get_protecting_group_info function.""" - info = get_protecting_group_info() - - # Check that all expected symbols are present - expected_symbols = {"$", "%", "&"} - self.assertEqual(set(info.keys()), expected_symbols) - - # Check structure of each entry - for symbol in expected_symbols: - self.assertIn(symbol, info) - entry = info[symbol] - - # Check required fields - required_fields = {"smiles", "name", "description"} - self.assertEqual(set(entry.keys()), required_fields) - - # Check data types - self.assertIsInstance(entry["smiles"], str) - self.assertIsInstance(entry["name"], str) - self.assertIsInstance(entry["description"], str) - - # Check that SMILES matches DEPROTECT_MAP - self.assertEqual(entry["smiles"], DEPROTECT_MAP[symbol]) - - def test_protecting_group_info_content(self): - """Test the specific content of protecting group information.""" - info = get_protecting_group_info() - - # Test OMe ($) - self.assertEqual(info["$"]["name"], "OMe") - self.assertEqual(info["$"]["smiles"], "OC") - self.assertEqual(info["$"]["description"], "Methoxy protecting group") - - # Test OBn (%) - self.assertEqual(info["%"]["name"], "OBn") - self.assertEqual(info["%"]["smiles"], "COCc1ccccc1") - self.assertEqual(info["%"]["description"], "Benzyl protecting group") - - # Test OEt (&) - self.assertEqual(info["&"]["name"], "OEt") - self.assertEqual(info["&"]["smiles"], "COC") - self.assertEqual(info["&"]["description"], "Ethoxy protecting group") - - -class TestValidation(unittest.TestCase): - """Test cases for the validation function.""" - - def test_validate_masked_smiles_valid_cases(self): - """Test validation of valid masked SMILES.""" - valid_cases = [ - "$", # Single OMe symbol - "%", # Single OBn symbol - "&", # Single OEt symbol - "CC(C)$", # Molecule with OMe - "CC(C)%", # Molecule with OBn - "CC(C)&", # Molecule with OEt - "CC(C)$%", # Molecule with multiple symbols - "CC(C)$%&", # Molecule with all symbols - "c1ccccc1$", # Aromatic with symbol - "CCO$%", # Alcohol with symbols - ] - - for smiles in valid_cases: - with self.subTest(smiles=smiles): - result = validate_masked_smiles(smiles) - self.assertTrue(result, f"Expected {smiles} to be valid") - - def test_validate_masked_smiles_invalid_cases(self): - """Test validation of invalid masked SMILES.""" - invalid_cases = [ - "CC(C)", # No symbols - "c1ccccc1", # No symbols - "CCO", # No symbols - "invalid", # Invalid characters - "CC(C)#", # Invalid symbol - "CC(C)@", # Invalid symbol - ] - - for smiles in invalid_cases: - with self.subTest(smiles=smiles): - result = validate_masked_smiles(smiles) - self.assertFalse(result, f"Expected {smiles} to be invalid") - - def test_validate_masked_smiles_edge_cases(self): - """Test validation of edge cases.""" - edge_cases = [ - ("", True), # Empty string should be valid - ("$$$", True), # Multiple symbols should be valid - ("%%%", True), # Multiple symbols should be valid - ("&&&", True), # Multiple symbols should be valid - ("$%&", True), # All symbols should be valid - ] - - for smiles, expected in edge_cases: - with self.subTest(smiles=smiles): - result = validate_masked_smiles(smiles) - self.assertEqual(result, expected, - f"Expected {smiles} to be {expected}") - - -class TestIntegration(unittest.TestCase): - """Integration tests for the deprotecting group module.""" - - def test_deprotect_map_consistency(self): - """Test that DEPROTECT_MAP is consistent with get_protecting_group_info.""" - info = get_protecting_group_info() - - for symbol, smiles in DEPROTECT_MAP.items(): - with self.subTest(symbol=symbol): - self.assertIn(symbol, info) - self.assertEqual(info[symbol]["smiles"], smiles) - - def test_unmask_validation_consistency(self): - """Test that unmasking works correctly with validated SMILES.""" - # Test that valid masked SMILES can be unmasked - valid_masked = ["$", "%", "&", "CC(C)$", "CC(C)%&"] - - for masked in valid_masked: - with self.subTest(masked=masked): - # Should be valid - self.assertTrue(validate_masked_smiles(masked)) - - # Should be unmaskable - result = unmask_protecting_groups_multisymbol(masked) - self.assertNotEqual(result, "INVALID_SMILES") - - def test_error_handling(self): - """Test error handling for various edge cases.""" - error_cases = [ - "INVALID_SMILES", # Special case - "$$$$$$", # Many symbols - "CC(C)$%&$%&", # Repeated symbols - ] - - for case in error_cases: - with self.subTest(case=case): - # Should handle other cases without crashing - result = unmask_protecting_groups_multisymbol(case) - self.assertIsInstance(result, str) - - def test_none_input(self): - """Test handling of None input.""" - # The function should handle None gracefully by returning empty string - # since None is falsy and the function checks "if not smiles:" - result = unmask_protecting_groups_multisymbol(None) - self.assertEqual(result, "") - - -def run_performance_test(): - """Run a simple performance test.""" - import time - - test_smiles = "CC(C)$%&" * 1000 # Create a long string - - start_time = time.time() - for _ in range(1000): - unmask_protecting_groups_multisymbol(test_smiles) - end_time = time.time() - - print( - f"Performance test: 1000 unmaskings took {end_time - start_time:.4f} seconds" - ) - - -if __name__ == "__main__": - # Run the unit tests - unittest.main(verbosity=2, exit=False) - - # Run performance test - print("\n" + "=" * 60) - print("Running performance test...") - run_performance_test() - - print("\n" + "=" * 60) - print("All tests completed!")