Skip to content

Commit

Permalink
[r0.26.0 Cherry-pick] Automatic autoreload of underlying modules a si…
Browse files Browse the repository at this point in the history
…ngle `_ModuleFinder` registered per module (#3439)

* fixes autoreload of underlying modules when used with tfx Interactive context

PiperOrigin-RevId: 358839958

* automatic autoreload of underlying modules
a single _ModuleFinder registered per module

PiperOrigin-RevId: 361515076

* Make _TfxModuleFinder thread safe by reusing existing lock, and remove an unnecessary variable

PiperOrigin-RevId: 363987501

Co-authored-by: tfx-team <tensorflow-extended-nonhuman@googlegroups.com>
Co-authored-by: zhitaoli <zhitaoli@google.com>
  • Loading branch information
3 people authored Mar 23, 2021
1 parent 629bea1 commit a960440
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 19 deletions.
86 changes: 69 additions & 17 deletions tfx/utils/import_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# Lint as: python2, python3
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -26,9 +25,6 @@
from absl import logging
from tfx.utils import io_utils

_imported_modules_from_source = {}
_imported_modules_from_source_lock = threading.Lock()


def import_class_by_path(class_path: Text) -> Type[Any]:
"""Import a class by its <module>.<name> path.
Expand All @@ -45,6 +41,63 @@ def import_class_by_path(class_path: Text) -> Type[Any]:
return getattr(mod, classname)


def import_func_from_module(module_path: Text, fn_name: Text) -> Callable: # pylint: disable=g-bare-generic
"""Imports a function from a module provided as source file or module path."""
user_module = importlib.import_module(module_path)
return getattr(user_module, fn_name)


# This lock is used both for access to class variables for _TfxModuleFinder
# and usage of that class, therefore must be RLock
# to avoid deadlock among multiple levels of call stack.
_imported_modules_from_source_lock = threading.RLock()


class _TfxModuleFinder(importlib.abc.MetaPathFinder):
"""Registers custom modules for Interactive Context."""

_modules = {} # a mapping fullname -> source_path

def find_spec(self, fullname, path, target=None):
del path
del target
with _imported_modules_from_source_lock:
if fullname not in self._modules:
return None
source_path = self._modules[fullname]
loader = importlib.machinery.SourceFileLoader(
fullname=fullname, path=source_path)
return importlib.util.spec_from_loader(
fullname, loader, origin=source_path)

def find_module(self, fullname, path):
pass

def register_module(self, fullname, path):
"""Registers and imports a new module."""
with _imported_modules_from_source_lock:
if fullname in self._modules:
raise ValueError('Module %s is already registered' % fullname)
self._modules[fullname] = path

def get_module_name_by_path(self, path):
with _imported_modules_from_source_lock:
for module_name, source_path in self._modules.items():
if source_path == path:
return module_name
return None

@property
def count_registered(self):
with _imported_modules_from_source_lock:
return len(self._modules)


_tfx_module_finder = _TfxModuleFinder()
with _imported_modules_from_source_lock:
sys.meta_path.append(_tfx_module_finder)


# TODO(b/175174419): Revisit the workaround for multiple invocations of
# import_func_from_source.
def import_func_from_source(source_path: Text, fn_name: Text) -> Callable: # pylint: disable=g-bare-generic
Expand All @@ -54,12 +107,14 @@ def import_func_from_source(source_path: Text, fn_name: Text) -> Callable: # py
# because importlib can't import from GCS
source_path = io_utils.ensure_local(source_path)

module = None
with _imported_modules_from_source_lock:
if source_path not in _imported_modules_from_source:
logging.info('Loading %s because it has not been loaded before.',
source_path)
if _tfx_module_finder.get_module_name_by_path(source_path) is None:
# Create a unique module name.
module_name = 'user_module_%d' % len(_imported_modules_from_source)
module_name = 'user_module_%d' % _tfx_module_finder.count_registered
logging.info(
'Loading source_path %s as name %s '
'because it has not been loaded before.', source_path, module_name)
try:
loader = importlib.machinery.SourceFileLoader(
fullname=module_name,
Expand All @@ -70,17 +125,14 @@ def import_func_from_source(source_path: Text, fn_name: Text) -> Callable: # py
module = importlib.util.module_from_spec(spec)
sys.modules[loader.name] = module
loader.exec_module(module)
_imported_modules_from_source[source_path] = module
_tfx_module_finder.register_module(module_name, source_path)
except IOError:
raise ImportError('{} in {} not found in '
'import_func_from_source()'.format(
fn_name, source_path))
else:
logging.info('%s is already loaded.', source_path)
return getattr(_imported_modules_from_source[source_path], fn_name)


def import_func_from_module(module_path: Text, fn_name: Text) -> Callable: # pylint: disable=g-bare-generic
"""Imports a function from a module provided as source file or module path."""
user_module = importlib.import_module(module_path)
return getattr(user_module, fn_name)
logging.info('%s is already loaded, reloading', source_path)
module_name = _tfx_module_finder.get_module_name_by_path(source_path)
module = sys.modules[module_name]
importlib.reload(module)
return getattr(module, fn_name)
27 changes: 25 additions & 2 deletions tfx/utils/import_utils_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# Lint as: python2, python3
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -19,7 +18,9 @@
from __future__ import print_function
from __future__ import unicode_literals

import importlib
import os
import sys
# Standard Imports

import tensorflow as tf
Expand All @@ -40,8 +41,8 @@ def testImportFuncFromSource(self):
test_fn_file = os.path.join(source_data_dir, 'test_fn.ext')
fn_1 = import_utils.import_func_from_source(test_fn_file, 'test_fn')
fn_2 = import_utils.import_func_from_source(test_fn_file, 'test_fn')
self.assertIs(fn_1, fn_2)
self.assertEqual(10, fn_1([1, 2, 3, 4]))
self.assertEqual(10, fn_2([1, 2, 3, 4]))

def testImportFuncFromSourceMissingFile(self):
source_data_dir = os.path.join(os.path.dirname(__file__), 'testdata')
Expand Down Expand Up @@ -69,6 +70,28 @@ def testImportFuncFromModuleModuleMissingFunction(self):
_ = import_utils.import_func_from_module(test_fn.test_fn.__module__,
'non_existing_fn')

def testtestImportFuncFromModuleReload(self):
temp_dir = self.create_tempdir().full_path
test_fn_file = os.path.join(temp_dir, 'fn.py')
with tf.io.gfile.GFile(test_fn_file, mode='w') as f:
f.write(
"""def test_fn(inputs):
return sum(inputs)
""")
count_registered = import_utils._tfx_module_finder.count_registered
fn_1 = import_utils.import_func_from_source(test_fn_file, 'test_fn')
self.assertEqual(10, fn_1([1, 2, 3, 4]))
with tf.io.gfile.GFile(test_fn_file, mode='w') as f:
f.write(
"""def test_fn(inputs):
return 1+sum(inputs)
""")
fn_2 = import_utils.import_func_from_source(test_fn_file, 'test_fn')
self.assertEqual(11, fn_2([1, 2, 3, 4]))
fn_3 = getattr(
importlib.reload(sys.modules['user_module_%d' % count_registered]),
'test_fn')
self.assertEqual(11, fn_3([1, 2, 3, 4]))

if __name__ == '__main__':
tf.test.main()

0 comments on commit a960440

Please sign in to comment.