Skip to content

Commit 3fc7718

Browse files
committed
update module importer
1 parent e25fc10 commit 3fc7718

File tree

1 file changed

+82
-72
lines changed

1 file changed

+82
-72
lines changed

metaflow/plugins/env_escape/client_modules.py

Lines changed: 82 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import atexit
22
import importlib
3+
import importlib.util
34
import itertools
45
import pickle
56
import re
@@ -41,6 +42,8 @@ def __init__(self, loader, prefix, exports, client):
4142
def __getattr__(self, name):
4243
if name == "__loader__":
4344
return self._loader
45+
if name == "__spec__":
46+
return importlib.util.spec_from_loader(self._prefix, self._loader)
4447
if name in ("__name__", "__package__"):
4548
return self._prefix
4649
if name in ("__file__", "__path__"):
@@ -71,7 +74,8 @@ def func(*args, **kwargs):
7174
# Try to see if this is a submodule that we can load
7275
m = None
7376
try:
74-
m = self._loader.load_module(".".join([self._prefix, name]))
77+
submodule_name = ".".join([self._prefix, name])
78+
m = importlib.import_module(submodule_name)
7579
except ImportError:
7680
pass
7781
if m is None:
@@ -117,7 +121,7 @@ def __setattr__(self, name, value):
117121

118122

119123
class ModuleImporter(object):
120-
# This ModuleImporter implements the Importer Protocol defined in PEP 302
124+
# This ModuleImporter implements the MetaPathFinder and Loader protocols (PEP 302/451)
121125
def __init__(
122126
self,
123127
python_executable,
@@ -135,84 +139,90 @@ def __init__(
135139
self._handled_modules = None
136140
self._aliases = {}
137141

138-
def find_module(self, fullname, path=None):
142+
def find_spec(self, fullname, path=None, target=None):
139143
if self._handled_modules is not None:
140144
if get_canonical_name(fullname, self._aliases) in self._handled_modules:
141-
return self
145+
return importlib.util.spec_from_loader(fullname, self)
142146
return None
143147
if any([fullname.startswith(prefix) for prefix in self._module_prefixes]):
144148
# We potentially handle this
145-
return self
149+
return importlib.util.spec_from_loader(fullname, self)
146150
return None
147151

148-
def load_module(self, fullname):
149-
if fullname in sys.modules:
150-
return sys.modules[fullname]
151-
if self._client is None:
152-
if sys.version_info[0] < 3:
153-
raise NotImplementedError(
154-
"Environment escape imports are not supported in Python 2"
155-
)
156-
# We initialize a client and query the modules we handle
157-
# The max_pickle_version is the pickle version that the server (so
158-
# the underlying interpreter we call into) supports; we determine
159-
# what version the current environment support and take the minimum
160-
# of those two
161-
max_pickle_version = min(self._max_pickle_version, pickle.HIGHEST_PROTOCOL)
162-
163-
self._client = Client(
164-
self._module_prefixes,
165-
self._python_executable,
166-
self._pythonpath,
167-
max_pickle_version,
168-
self._config_dir,
169-
)
170-
atexit.register(_clean_client, self._client)
171-
172-
# Get information about overrides and what the server knows about
173-
exports = self._client.get_exports()
174-
175-
prefixes = set()
176-
export_classes = exports.get("classes", [])
177-
export_functions = exports.get("functions", [])
178-
export_values = exports.get("values", [])
179-
export_exceptions = exports.get("exceptions", [])
180-
self._aliases = exports.get("aliases", {})
181-
for name in itertools.chain(
182-
export_classes,
183-
export_functions,
184-
export_values,
185-
(e[0] for e in export_exceptions),
186-
):
187-
splits = name.rsplit(".", 1)
188-
prefixes.add(splits[0])
189-
# We will make sure that we create modules even for "empty" prefixes
190-
# because packages are always loaded hierarchically so if we have
191-
# something in `a.b.c` but nothing directly in `a`, we still need to
192-
# create a module named `a`. There is probably a better way of doing this
193-
all_prefixes = list(prefixes)
194-
for prefix in all_prefixes:
195-
parts = prefix.split(".")
196-
cur = parts[0]
197-
for i in range(1, len(parts)):
198-
prefixes.add(cur)
199-
cur = ".".join([cur, parts[i]])
200-
201-
# We now know all the modules that we can handle. We update
202-
# handled_module and return the module if we have it or raise ImportError
203-
self._handled_modules = {}
204-
for prefix in prefixes:
205-
self._handled_modules[prefix] = _WrappedModule(
206-
self, prefix, exports, self._client
207-
)
152+
def create_module(self, spec):
153+
# Return None to use default module creation
154+
return None
155+
156+
def exec_module(self, module):
157+
fullname = module.__name__
158+
159+
self._initialize_client()
160+
208161
canonical_fullname = get_canonical_name(fullname, self._aliases)
209-
# Modules are created canonically but we need to return something for any
210-
# of the aliases.
211-
module = self._handled_modules.get(canonical_fullname)
212-
if module is None:
213-
raise ImportError
214-
sys.modules[fullname] = module
215-
return module
162+
# Modules are created canonically but we need to handle any of the aliases.
163+
wrapped_module = self._handled_modules.get(canonical_fullname)
164+
if wrapped_module is None:
165+
raise ImportError(f"No module named '{fullname}'")
166+
167+
# Replace the standard module with the wrapped module in sys.modules
168+
sys.modules[fullname] = wrapped_module
169+
170+
def _initialize_client(self):
171+
if self._client is not None:
172+
return
173+
174+
# We initialize a client and query the modules we handle
175+
# The max_pickle_version is the pickle version that the server (so
176+
# the underlying interpreter we call into) supports; we determine
177+
# what version the current environment support and take the minimum
178+
# of those two
179+
max_pickle_version = min(self._max_pickle_version, pickle.HIGHEST_PROTOCOL)
180+
181+
self._client = Client(
182+
self._module_prefixes,
183+
self._python_executable,
184+
self._pythonpath,
185+
max_pickle_version,
186+
self._config_dir,
187+
)
188+
atexit.register(_clean_client, self._client)
189+
190+
# Get information about overrides and what the server knows about
191+
exports = self._client.get_exports()
192+
193+
prefixes = set()
194+
export_classes = exports.get("classes", [])
195+
export_functions = exports.get("functions", [])
196+
export_values = exports.get("values", [])
197+
export_exceptions = exports.get("exceptions", [])
198+
self._aliases = exports.get("aliases", {})
199+
for name in itertools.chain(
200+
export_classes,
201+
export_functions,
202+
export_values,
203+
(e[0] for e in export_exceptions),
204+
):
205+
splits = name.rsplit(".", 1)
206+
prefixes.add(splits[0])
207+
# We will make sure that we create modules even for "empty" prefixes
208+
# because packages are always loaded hierarchically so if we have
209+
# something in `a.b.c` but nothing directly in `a`, we still need to
210+
# create a module named `a`. There is probably a better way of doing this
211+
all_prefixes = list(prefixes)
212+
for prefix in all_prefixes:
213+
parts = prefix.split(".")
214+
cur = parts[0]
215+
for i in range(1, len(parts)):
216+
prefixes.add(cur)
217+
cur = ".".join([cur, parts[i]])
218+
219+
# We now know all the modules that we can handle. We update
220+
# handled_module and return the module if we have it or raise ImportError
221+
self._handled_modules = {}
222+
for prefix in prefixes:
223+
self._handled_modules[prefix] = _WrappedModule(
224+
self, prefix, exports, self._client
225+
)
216226

217227

218228
def create_modules(python_executable, pythonpath, max_pickle_version, path, prefixes):

0 commit comments

Comments
 (0)