Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions ffn/inference/inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from . import storage


# Names retained for compatibility with the MR interface.
# TODO(mjanusz): Drop this requirement or provide a wrapper class for the MR
# counter so that this is no longer necessary.
Expand All @@ -35,11 +36,11 @@ def __init__(self, update, name, parent=None):
"""Initializes the counter.

Args:
update: callable taking no arguments; will be called when
the counter is incremented
update: callable taking no arguments; will be called when the counter is
incremented
name: name of the counter to use for streamz
parent: optional StatCounter object to which to propagate
any updates of the current counter
parent: optional StatCounter object to which to propagate any updates of
the current counter
"""
self._counter = 0
self._update = update
Expand Down Expand Up @@ -80,6 +81,7 @@ def __repr__(self):
def value(self):
return self._counter


# pylint: enable=invalid-name

MSEC_IN_SEC = 1000
Expand Down Expand Up @@ -137,8 +139,7 @@ def dump(self, filename: str):
fd.write('%s: %d\n' % (name, counter.value))

def dumps(self) -> str:
state = {name: counter.value for name, counter in
self._counters.items()}
state = {name: counter.value for name, counter in self._counters.items()}
return json.dumps(state)

def loads(self, encoded_state: str):
Expand All @@ -150,13 +151,16 @@ def loads(self, encoded_state: str):


@contextlib.contextmanager
def timer_counter(counters: Counters, name: str, export=True):
def timer_counter(
counters: Counters, name: str, export=True, increment: int = 1
):
"""Creates a counter tracking time spent in the context.

Args:
counters: Counters object
name: counter name
export: whether to export counter via streamz
increment: value by which to increment the underlying counter

Yields:
tuple of two counters, to track number of calls and time spent on them,
Expand All @@ -169,7 +173,7 @@ def timer_counter(counters: Counters, name: str, export=True):
try:
yield timer, counter
finally:
counter.Increment()
counter.IncrementBy(increment)
dt = (time.time() - start_time) * MSEC_IN_SEC
timer.IncrementBy(dt)

Expand Down Expand Up @@ -204,9 +208,8 @@ def match_histogram(image, lut, mask=None):
Args:
image: (z, y, x) ndarray with the source image
lut: lookup table from `compute_histogram_lut`
mask: optional Boolean mask defining areas that
are NOT to be considered for CDF calculation
after applying CLAHE
mask: optional Boolean mask defining areas that are NOT to be considered for
CDF calculation after applying CLAHE

Returns:
None; `image` is modified in place
Expand All @@ -222,12 +225,12 @@ def match_histogram(image, lut, mask=None):
if valid_slice.size == 0:
continue

cdf, bins = skimage.exposure.cumulative_distribution(
valid_slice)
cdf, bins = skimage.exposure.cumulative_distribution(valid_slice)
cdf = np.array(cdf.tolist() + [1.0])
bins = np.array(bins.tolist() + [255])
image[z, ...] = lut[
(cdf[np.searchsorted(bins, clahe_slice)] * 255).astype(np.uint8)]
(cdf[np.searchsorted(bins, clahe_slice)] * 255).astype(np.uint8)
]


def compute_histogram_lut(image):
Expand Down
Loading