Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions gin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from gin.config import exit_interactive_mode
from gin.config import external_configurable
from gin.config import finalize
from gin.config import get_bindings
from gin.config import operative_config_str
from gin.config import parse_config
from gin.config import parse_config_file
Expand Down
95 changes: 78 additions & 17 deletions gin/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def drink(cocktail):
import sys
import threading
import traceback
from typing import Optional, Sequence, Union
from typing import Any, Callable, Dict, Optional, Sequence, Type, Union

from gin import config_parser
from gin import selector_map
Expand Down Expand Up @@ -139,6 +139,8 @@ def exit_scope(self):

# Maintains the registry of configurable functions and classes.
_REGISTRY = selector_map.SelectorMap()
# Inverse registery to recover a binding from a function or class
_FN_OR_CLS_TO_SELECTOR = {}

# Maps tuples of `(scope, selector)` to associated parameter values. This
# specifies the current global "configuration" set through `bind_parameter` or
Expand Down Expand Up @@ -983,6 +985,51 @@ def load_eval_data():
_SCOPE_MANAGER.exit_scope()


def get_bindings(
fn_or_cls: Union[str, Callable[..., Any], Type[Any]],
) -> Dict[str, Any]:
"""Returns the bindings associated with the given configurable.

Example:

```python
config.parse_config('MyParams.kwarg0 = 123')

gin.get_bindings('MyParams') == {'kwarg0': 123}
```

Note: The scope in which `get_bindings` is called will be used.

Args:
fn_or_cls: Configurable function, class or selector `str` too.

Returns:
The bindings kwargs injected by gin.
"""
if isinstance(fn_or_cls, str):
# Resolve partial selector -> full selector
selector = _REGISTRY.get_match(fn_or_cls)
if selector:
selector = selector.selector
else:
selector = _FN_OR_CLS_TO_SELECTOR.get(fn_or_cls)

if selector is None:
raise ValueError(f'Could not find {fn_or_cls} in the gin register.')

return _get_bindings(selector)


def _get_bindings(selector: str) -> Dict[str, Any]:
"""Returns the bindings for the current full selector."""
scope_components = current_scope()
new_kwargs = {}
for i in range(len(scope_components) + 1):
partial_scope_str = '/'.join(scope_components[:i])
new_kwargs.update(_CONFIG.get((partial_scope_str, selector), {}))
return new_kwargs


def _make_gin_wrapper(fn, fn_or_cls, name, selector, allowlist, denylist):
"""Creates the final Gin wrapper for the given function.

Expand Down Expand Up @@ -1015,13 +1062,9 @@ def _make_gin_wrapper(fn, fn_or_cls, name, selector, allowlist, denylist):
@functools.wraps(fn)
def gin_wrapper(*args, **kwargs):
"""Supplies fn with parameter values from the configuration."""
scope_components = current_scope()
new_kwargs = {}
for i in range(len(scope_components) + 1):
partial_scope_str = '/'.join(scope_components[:i])
new_kwargs.update(_CONFIG.get((partial_scope_str, selector), {}))
new_kwargs = _get_bindings(selector)
gin_bound_args = list(new_kwargs.keys())
scope_str = partial_scope_str
scope_str = '/'.join(current_scope())

arg_names = _get_supplied_positional_parameter_names(signature_fn, args)

Expand Down Expand Up @@ -1147,6 +1190,27 @@ def gin_wrapper(*args, **kwargs):
return gin_wrapper


def _make_selector(
fn_or_cls,
*,
name: Optional[str],
module: Optional[str],
) -> str:
"""Returns the gin name selector."""
name = fn_or_cls.__name__ if name is None else name
if config_parser.IDENTIFIER_RE.match(name):
default_module = getattr(fn_or_cls, '__module__', None)
module = default_module if module is None else module
elif not config_parser.MODULE_RE.match(name):
raise ValueError("Configurable name '{}' is invalid.".format(name))

if module is not None and not config_parser.MODULE_RE.match(module):
raise ValueError("Module '{}' is invalid.".format(module))

selector = module + '.' + name if module else name
return selector


def _make_configurable(fn_or_cls,
name=None,
module=None,
Expand Down Expand Up @@ -1188,17 +1252,12 @@ def _make_configurable(fn_or_cls,
err_str = 'Attempted to add a new configurable after the config was locked.'
raise RuntimeError(err_str)

name = fn_or_cls.__name__ if name is None else name
if config_parser.IDENTIFIER_RE.match(name):
default_module = getattr(fn_or_cls, '__module__', None)
module = default_module if module is None else module
elif not config_parser.MODULE_RE.match(name):
raise ValueError("Configurable name '{}' is invalid.".format(name))

if module is not None and not config_parser.MODULE_RE.match(module):
raise ValueError("Module '{}' is invalid.".format(module))
selector = _make_selector(
fn_or_cls,
name=name,
module=module,
)

selector = module + '.' + name if module else name
if not _INTERACTIVE_MODE and selector in _REGISTRY:
err_str = ("A configurable matching '{}' already exists.\n\n"
'To allow re-registration of configurables in an interactive '
Expand Down Expand Up @@ -1234,6 +1293,8 @@ def decorator(fn):
allowlist=allowlist,
denylist=denylist,
selector=selector)
# Inverse registery
_FN_OR_CLS_TO_SELECTOR[decorated_fn_or_cls] = selector
return decorated_fn_or_cls


Expand Down
68 changes: 68 additions & 0 deletions tests/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1985,6 +1985,74 @@ def testEmptyNestedIncludesAndImports(self):
[], ['TEST=1'], print_includes_and_imports=True)
self.assertListEqual(result, [])

def testGetBindings(self):
# Bindings can be accessed through name or object
# Default are empty
self.assertDictEqual(config.get_bindings('configurable1'), {})
self.assertDictEqual(config.get_bindings(fn1), {})

self.assertDictEqual(config.get_bindings('ConfigurableClass'), {})
self.assertDictEqual(config.get_bindings(ConfigurableClass), {})

config_str = """
configurable1.non_kwarg = 'kwarg1'
configurable1.kwarg2 = 123
ConfigurableClass.kwarg1 = 'okie dokie'
"""
config.parse_config(config_str)

self.assertDictEqual(config.get_bindings('configurable1'), {
'non_kwarg': 'kwarg1',
'kwarg2': 123,
})
self.assertDictEqual(config.get_bindings(fn1), {
'non_kwarg': 'kwarg1',
'kwarg2': 123,
})

self.assertDictEqual(config.get_bindings('ConfigurableClass'), {
'kwarg1': 'okie dokie',
})
self.assertDictEqual(config.get_bindings(ConfigurableClass), {
'kwarg1': 'okie dokie',
})

def testGetBindingsScope(self):
config_str = """
configurable1.non_kwarg = 'kwarg1'
configurable1.kwarg2 = 123
scope/configurable1.kwarg2 = 456
"""
config.parse_config(config_str)

self.assertDictEqual(config.get_bindings('configurable1'), {
'non_kwarg': 'kwarg1',
'kwarg2': 123,
})
self.assertDictEqual(config.get_bindings(fn1), {
'non_kwarg': 'kwarg1',
'kwarg2': 123,
})

with config.config_scope('scope'):
self.assertDictEqual(config.get_bindings('configurable1'), {
'non_kwarg': 'kwarg1',
'kwarg2': 456,
})
self.assertDictEqual(config.get_bindings(fn1), {
'non_kwarg': 'kwarg1',
'kwarg2': 456,
})

def testGetBindingsUnknown(self):

expected_msg = 'Could not find .* in the gin register'
with self.assertRaisesRegex(ValueError, expected_msg):
config.get_bindings('UnknownParam')

with self.assertRaisesRegex(ValueError, expected_msg):
config.get_bindings(lambda x: None)


if __name__ == '__main__':
absltest.main()