Skip to content

Commit

Permalink
Ensure base path for included modules kept in sys.path (#406)
Browse files Browse the repository at this point in the history
Co-authored-by: Akshita Bhagia <[email protected]>
  • Loading branch information
epwalsh and AkshitaB authored Sep 23, 2022
1 parent e1a1cd1 commit b1d6431
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 38 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed a bug where creating a `StepInfo` object from params might result in unnecessary imports.
- Fixed a bug where canceling the Beaker executor might not work properly.
- Fixed a bug where the trainer trains too much when `train_epochs` is set and you're using gradient accumulation.
- Fixed a bug where included modules might not be found when using multiprocessing when they're not on `sys.path` / `PYTHONPATH`.
- Fixed how the results of uncacheable steps are displayed by `tango run`.
- Beaker executor won't run duplicate cacheable steps at the same time.

Expand Down
58 changes: 20 additions & 38 deletions tango/common/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
import sys
import traceback
from collections import OrderedDict
from contextlib import contextmanager
from dataclasses import asdict, is_dataclass
from datetime import datetime, tzinfo
from pathlib import Path
from typing import Any, Iterable, Optional, Set, Tuple, Union

import pytz

from .aliases import PathOrStr
from .exceptions import SigTermReceived


Expand All @@ -35,25 +33,6 @@ def install_sigterm_handler():
signal.signal(signal.SIGTERM, _handle_sigterm)


@contextmanager
def push_python_path(path: PathOrStr):
"""
Prepends the given path to `sys.path`.
This method is intended to use with `with`, so after its usage, its value willbe removed from
`sys.path`.
"""
# In some environments, such as TC, it fails when sys.path contains a relative path, such as ".".
path = Path(path).resolve()
path = str(path)
sys.path.insert(0, path)
try:
yield
finally:
# Better to remove by value, in case `sys.path` was manipulated in between.
sys.path.remove(path)


_extra_imported_modules: Set[str] = set()


Expand Down Expand Up @@ -108,29 +87,32 @@ def import_module_and_submodules(package_name: str, exclude: Optional[Set[str]]
package_name, base_path = resolve_module_name(package_name)
else:
base_path = Path(".")
base_path = base_path.resolve()

if exclude and package_name in exclude:
return

importlib.invalidate_caches()

# For some reason, python doesn't always add this by default to your path, but you pretty much
# always want it when using `--include-package`. And if it's already there, adding it again at
# the end won't hurt anything.
with push_python_path(base_path):
# Import at top level
module = importlib.import_module(package_name)
path = getattr(module, "__path__", [])
path_string = "" if not path else path[0]

# walk_packages only finds immediate children, so need to recurse.
for module_finder, name, _ in pkgutil.walk_packages(path):
# Sometimes when you import third-party libraries that are on your path,
# `pkgutil.walk_packages` returns those too, so we need to skip them.
if path_string and module_finder.path != path_string: # type: ignore[union-attr]
continue
subpackage = f"{package_name}.{name}"
import_module_and_submodules(subpackage, exclude=exclude)
# Ensure `base_path` is first in `sys.path`.
if str(base_path) not in sys.path:
sys.path.insert(0, str(base_path))
else:
sys.path.insert(0, sys.path.pop(sys.path.index(str(base_path))))

# Import at top level
module = importlib.import_module(package_name)
path = getattr(module, "__path__", [])
path_string = "" if not path else path[0]

# walk_packages only finds immediate children, so need to recurse.
for module_finder, name, _ in pkgutil.walk_packages(path):
# Sometimes when you import third-party libraries that are on your path,
# `pkgutil.walk_packages` returns those too, so we need to skip them.
if path_string and module_finder.path != path_string: # type: ignore[union-attr]
continue
subpackage = f"{package_name}.{name}"
import_module_and_submodules(subpackage, exclude=exclude)


def _parse_bool(value: Union[bool, str]) -> bool:
Expand Down

0 comments on commit b1d6431

Please sign in to comment.