# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
Various low-level utilities.
"""
import collections
import datetime
import errno
import functools
import json
import math
import multiprocessing
import operator
import os
import re
import select
import shlex
import shutil
import signal
import stat
import subprocess
import sys
import threading
import time
import json5
from asv_runner.util import _human_time_units, human_float, human_time
if not WIN:
from select import PIPE_BUF
[docs]
# Width of the terminal in columns, queried once when this module is imported.
terminal_width = shutil.get_terminal_size().columns
[docs]
class UserError(Exception):
    """
    Error that is reported to the user, e.g. for malformed JSON files or
    invalid configuration values.
    """
[docs]
class ParallelFailure(Exception):
    """
    Custom exception to work around a multiprocessing bug
    https://bugs.python.org/issue9400

    Carries the message, the class and the formatted traceback of another
    exception, and defines ``__reduce__`` so that it can be rebuilt from
    those three values.
    """

    def __new__(cls, message, exc_cls, traceback_str):
        instance = Exception.__new__(cls)
        instance.message, instance.exc_cls, instance.traceback_str = (
            message,
            exc_cls,
            traceback_str,
        )
        return instance

    def __reduce__(self):
        # Rebuild through the three-argument constructor.
        ctor_args = (self.message, self.exc_cls, self.traceback_str)
        return (ParallelFailure, ctor_args)

    def __str__(self):
        # Indent continuation lines of the traceback text.
        indented_traceback = self.traceback_str.replace("\n", "\n ")
        return "{}: {}\n {}".format(self.exc_cls.__name__, self.message, indented_traceback)

    def reraise(self):
        """
        Raise a fresh `UserError` if that was the wrapped exception class;
        otherwise raise this object itself.
        """
        if self.exc_cls is not UserError:
            raise self
        raise UserError(self.message)
[docs]
def human_list(input_list):
    """
    Join the items into an English enumeration.

    Every item is wrapped in single quotes; items are separated by commas
    except for the last two, which are joined with "and". An empty input
    yields the word ``nothing``.
    """
    quoted = [f"'{x}'" for x in input_list]
    if not quoted:
        return 'nothing'
    *leading, last = quoted
    if not leading:
        return last
    return ', '.join(leading) + ' and ' + last
[docs]
def human_file_size(size, err=None):
    """
    Return a short human-friendly string for a file size.

    The size is scaled by the largest power of 1000 that does not exceed it
    and suffixed with one of ``k``, ``M``, ``G``, ``T``, ``P``, ``E``, ``H``
    (``?`` beyond that). Values below 1000 carry no suffix, and values
    below 1 are treated as 0.

    Parameters
    ----------
    size : int
        The size of the file (in bytes)
    err : float, optional
        Error of the size (in bytes). When given, it is scaled by the same
        factor and appended after a ``±`` sign.

    Returns
    -------
    size : str
        A human-friendly representation of the size of the file
    """
    size = float(size)
    if size < 1:
        size = 0.0

    suffixes = ' kMGTPEH'
    # Exponent of 1000 used for scaling.
    num_scale = 0 if size == 0 else math.floor(math.log(size) / math.log(1000))
    suffix = '?' if num_scale > 7 else suffixes[num_scale].strip()

    scale = int(math.pow(1000, num_scale))
    str_value = human_float(size / scale, 3)
    if err is not None:
        str_err = human_float(err / scale, 1, truncate_small=2)
        return f"{str_value:s}±{str_err:s}{suffix}"
    return f"{str_value:s}{suffix}"
[docs]
def human_value(value, unit, err=None):
    """
    Render a value given in some unit as a human-friendly string.

    Parameters
    ----------
    value : anything
        The value to format. ``None`` is shown as ``failed`` and NaN as
        ``n/a``; non-numeric values are JSON-encoded.
    unit : str
        The unit the value is in. Currently understands `seconds` and `bytes`;
        numbers in any other unit are JSON-encoded.
    err : float, optional
        Std. error in the value

    Returns
    -------
    display : str
    """
    if value is None:
        return "failed"
    if not isinstance(value, (int, float)):
        return json.dumps(value)
    if value != value:
        # nan
        return "n/a"
    if unit == 'seconds':
        return human_time(value, err=err)
    if unit == 'bytes':
        return human_file_size(value, err=err)

    display = json.dumps(value)
    if err is not None:
        display += f"±{err:.2g}"
    return display
[docs]
def parse_human_time(string, base_period='d'):
    """
    Convert a human-specified time period into a number.

    The accepted format is ``<number><suffix>``, where the suffix is one of
    the keys of ``_human_time_units``. A missing suffix is interpreted as
    `base_period`. The result is the number multiplied by the unit's value
    (presumably seconds -- confirm against ``_human_time_units``).

    Raises a ValueError on parse error.
    """
    unit_values = dict(_human_time_units)
    unit_values[''] = unit_values[base_period]
    suffixes = '|'.join(unit_values)

    try:
        parsed = re.match(rf'^\s*([0-9.]+)\s*({suffixes})\s*$', string)
        if parsed is None:
            raise ValueError()
        number = float(parsed.group(1))
        return number * unit_values[parsed.group(2)]
    except ValueError:
        raise ValueError(f"{string!r} is not a valid time period (valid units: {suffixes})")
[docs]
def which(filename, paths=None):
    """
    Locate an executable, like the UNIX `which` command.

    Parameters
    ----------
    filename : str
        Name of the program. If it contains a path separator it is looked
        up as given, without searching any directories.
    paths : list of str, optional
        Directories to search instead of the ``PATH`` environment variable.

    Returns
    -------
    path : str
        The first existing file or symlink found.

    Raises an OSError if no result is found.
    """
    # Hide traceback from expected exceptions in pytest reports
    __tracebackhide__ = operator.methodcaller('errisinstance', OSError)

    if os.path.sep in filename:
        locations = ['']
    elif paths is not None:
        locations = paths
    else:
        locations = os.environ.get("PATH", "").split(os.pathsep)
        if WIN:
            # On windows, an entry in %PATH% may be quoted
            locations = [
                path[1:-1] if len(path) > 2 and path[0] == path[-1] == '"' else path
                for path in locations
            ]

    if WIN:
        variants = [filename + ext for ext in ('.exe', '.bat', '.com', '')]
    else:
        variants = [filename]

    # Directories are tried in order; within a directory, name variants in order.
    for location in locations:
        for variant in variants:
            candidate = os.path.join(location, variant)
            if os.path.isfile(candidate) or os.path.islink(candidate):
                return candidate

    loc_info = 'PATH' if paths is None else os.pathsep.join(locations)
    raise OSError(f"Could not find '{filename}' in {loc_info}")
[docs]
def has_command(filename):
    """
    Tell whether `which` can locate the given commandline utility.

    Returns `True` if it is found and `False` if `which` raises OSError.
    """
    try:
        which(filename)
    except OSError:
        return False
    return True
[docs]
class ProcessError(subprocess.CalledProcessError):
    """
    Raised when a subprocess exits with an unexpected return code or
    times out.

    Parameters
    ----------
    args : list of str
        The command line that was run.
    retcode : int
        Exit status of the process (``TIMEOUT_RETCODE`` on timeout).
    stdout : str
        Captured standard output.
    stderr : str
        Captured standard error.
    """

    def __init__(self, args, retcode, stdout, stderr):
        # Store everything on the instance; __str__ reads `args` and `retcode`.
        self.args = args
        self.retcode = retcode
        self.stdout = stdout
        self.stderr = stderr

    def __str__(self):
        command = ' '.join(self.args)
        if self.retcode == TIMEOUT_RETCODE:
            return f"Command '{command}' timed out"
        return "Command '{}' returned non-zero exit status {}".format(command, self.retcode)
[docs]
def check_call(
    args,
    valid_return_codes=(0,),
    timeout=600,
    dots=True,
    display_error=True,
    shell=False,
    env=None,
    cwd=None,
):
    """
    Run a command in a subprocess and discard its output.

    Delegates to `check_output` (see there for the parameters), which
    raises ProcessError if the command fails. Nothing is returned.
    """
    # Hide traceback from expected exceptions in pytest reports
    __tracebackhide__ = operator.methodcaller('errisinstance', ProcessError)

    forwarded = {
        'valid_return_codes': valid_return_codes,
        'timeout': timeout,
        'dots': dots,
        'display_error': display_error,
        'shell': shell,
        'env': env,
        'cwd': cwd,
    }
    check_output(args, **forwarded)
[docs]
class DebugLogBuffer:
    """
    Callable that buffers chunks of subprocess output and forwards them
    to ``log.debug``.

    Chunks are accumulated until one containing a newline arrives; the
    buffered text up to and including the first newline of that chunk is
    then emitted. Calling the object with ``None`` flushes whatever is
    still buffered.

    Parameters
    ----------
    log : object
        Logger providing a ``debug(text, continued=True)`` method.
    """

    def __init__(self, log):
        self.buf = []  # pending byte chunks not yet written to the log
        self.first = True  # whether the 'OUTPUT' banner still has to be printed
        self.linebreak_re = re.compile(b'.*\n')
        self.log = log
        self.lock = threading.Lock()

    def __call__(self, c):
        # Serialize access: the buffer may be fed from several reader threads.
        with self.lock:
            self._process(c)

    def _process(self, c):
        """Buffer the byte chunk `c`, or flush everything if `c` is None."""
        if c is None:
            text = b"".join(self.buf)
            del self.buf[:]
        elif b'\n' in c:
            # Split after the first newline; the remainder stays buffered.
            m = self.linebreak_re.match(c)
            j = m.end()
            self.buf.append(c[:j])
            text = b"".join(self.buf)
            self.buf[:] = [c[j:]]
        else:
            self.buf.append(c)
            return

        text = text.decode('utf-8', 'replace')
        text = text.removesuffix('\n')

        if text:
            if self.first:
                self.log.debug('OUTPUT -------->', continued=True)
                self.first = False
            self.log.debug(text, continued=True)
[docs]
def check_output(
    args,
    valid_return_codes=(0,),
    timeout=600,
    dots=True,
    display_error=True,
    shell=False,
    return_stderr=False,
    env=None,
    cwd=None,
    redirect_stderr=False,
    return_popen=False,
):
    """
    Runs the given command in a subprocess, raising ProcessError if it
    fails. Returns stdout as a string on success.

    Parameters
    ----------
    args : str or list of str
        The command to run. A single string is wrapped into a one-element
        list.
    valid_return_codes : list, optional
        A list of return codes to ignore. Defaults to only ignoring zero.
        Setting to None ignores all return codes.
    timeout : number, optional
        Kill the process if it does not produce any output in `timeout`
        seconds. If `None`, there is no timeout.
        Default: 10 min
    dots : bool, optional
        If `True` (default) write a dot to the console to show
        progress as the subprocess outputs content. May also be
        a callback function to call (with no arguments) to indicate
        progress. Ignored when debug logging is enabled.
    display_error : bool, optional
        If `True` (default) display the stdout and stderr of the
        subprocess when the subprocess returns an error code.
    shell : bool, optional
        If `True`, run the command through the shell. Default is
        `False`.
    return_stderr : bool, optional
        If `True`, return both the (stdout, stderr, errcode) as a
        tuple.
    env : dict, optional
        Specify environment variables for the subprocess.
    cwd : str, optional
        Specify the current working directory to use when running the
        process.
    redirect_stderr : bool, optional
        Whether to redirect stderr to stdout. In this case the returned
        ``stderr`` (when return_stderr == True) is an empty string.
    return_popen : bool, optional
        Whether to return immediately after subprocess.Popen.

    Returns
    -------
    stdout, stderr, retcode : when return_stderr == True
    proc : the `subprocess.Popen` object, when return_popen == True
    stdout : otherwise

    Raises
    ------
    ProcessError
        If `valid_return_codes` is not None and the return code (or
        ``TIMEOUT_RETCODE`` after a timeout) is not among them.
    """
    from .console import log

    # Hide traceback from expected exceptions in pytest reports
    __tracebackhide__ = operator.methodcaller('errisinstance', ProcessError)

    def get_content(header=None):
        # Build the error report; reads `stdout`/`stderr` of the enclosing
        # scope, so it may only be called after they have been decoded.
        # The last character of each stream is cut off ([:-1]).
        content = []
        if header is not None:
            content.append(header)
        if redirect_stderr:
            content.extend(['OUTPUT -------->', stdout[:-1]])
        else:
            content.extend(['STDOUT -------->', stdout[:-1], 'STDERR -------->', stderr[:-1]])
        return '\n'.join(content)

    if isinstance(args, str):
        args = [args]

    log.debug(f"Running '{' '.join(args)}'")

    kwargs = {
        'shell': shell,
        'env': env,
        'cwd': cwd,
        'stdout': subprocess.PIPE,
        'stderr': subprocess.PIPE,
    }
    if redirect_stderr:
        kwargs['stderr'] = subprocess.STDOUT
    if WIN:
        kwargs['close_fds'] = False
        kwargs['creationflags'] = subprocess.CREATE_NEW_PROCESS_GROUP
    else:
        kwargs['close_fds'] = True
        # `posix` is truthy when the platform provides os.setpgid
        posix = getattr(os, 'setpgid', None)
        if posix:
            # Run the subprocess in a separate process group, so that we
            # can kill it and all child processes it spawns e.g. on
            # timeouts. Note that subprocess.Popen will wait until exec()
            # before returning in parent process, so there is no race
            # condition in setting the process group vs. calls to os.killpg
            kwargs['preexec_fn'] = lambda: os.setpgid(0, 0)

    proc = subprocess.Popen(args, **kwargs)

    if return_popen:
        return proc

    last_dot_time = time.time()
    stdout_chunks = []
    stderr_chunks = []
    is_timeout = False

    if log.is_debug_enabled():
        # Output is echoed to the debug log, so progress dots are disabled
        debug_log = DebugLogBuffer(log)
        dots = False
    else:

        def debug_log(c):
            return None

    if WIN:
        # Windows: read the pipes from daemon threads and poll them from
        # this thread. `start_time` is a one-element list so that the
        # reader threads can update it; it holds the time of the most
        # recent output and the timeout is measured from it.
        start_time = [time.time()]
        dot_start_time = start_time[0]
        is_timeout = False

        def stream_reader(stream, buf):
            try:
                while not is_timeout:
                    # One byte at a time, so every piece of output
                    # refreshes `start_time`
                    c = stream.read(1)
                    if not c:
                        break
                    start_time[0] = time.time()
                    buf.append(c)
                    debug_log(c)
            finally:
                stream.close()

        stdout_reader = threading.Thread(target=stream_reader, args=(proc.stdout, stdout_chunks))
        stdout_reader.daemon = True
        stdout_reader.start()
        all_threads = [stdout_reader]

        if not redirect_stderr:
            stderr_reader = threading.Thread(
                target=stream_reader, args=(proc.stderr, stderr_chunks)
            )
            stderr_reader.daemon = True
            stderr_reader.start()
            all_threads.append(stderr_reader)

        # Wait for reader threads
        threads = list(all_threads)
        while threads:
            thread = threads[0]

            if timeout is None:
                remaining = None
            else:
                remaining = timeout - (time.time() - start_time[0])
                if remaining <= 0:
                    # Timeout; we won't wait for the thread to join here
                    if not is_timeout:
                        is_timeout = True
                        proc.send_signal(signal.CTRL_BREAK_EVENT)
                    threads.pop(0)
                    continue

            if dots:
                # Wake up at least every 0.5 s to print progress dots
                dot_remaining = 0.5 - (time.time() - last_dot_time)
                if dot_remaining <= 0:
                    # Print a dot only if there has been output
                    if dot_start_time != start_time[0]:
                        if dots is True:
                            log.dot()
                        elif dots:
                            dots()
                        dot_start_time = start_time[0]
                    last_dot_time = time.time()
                    dot_remaining = 0.5

                if remaining is None:
                    remaining = dot_remaining
                else:
                    remaining = min(dot_remaining, remaining)

            thread.join(remaining)
            if not thread.is_alive():
                threads.pop(0)

        if is_timeout:
            proc.terminate()

        # Wait a bit for the reader threads, if they're alive
        for thread in all_threads:
            thread.join(0.1)

        # Wait for process to exit
        proc.wait()
    else:
        # Other platforms: multiplex the pipes with select() in this thread
        try:
            if posix and is_main_thread():
                # Forward signals related to Ctrl-Z handling; the child
                # process is in a separate process group so it won't receive
                # these automatically from the terminal
                def sig_forward(signum, frame):
                    _killpg_safe(proc.pid, signum)
                    if signum == signal.SIGTSTP:
                        os.kill(os.getpid(), signal.SIGSTOP)

                signal.signal(signal.SIGTSTP, sig_forward)
                signal.signal(signal.SIGCONT, sig_forward)

            # Map each pipe's file descriptor to the list collecting its data
            fds = {proc.stdout.fileno(): stdout_chunks}
            if not redirect_stderr:
                fds[proc.stderr.fileno()] = stderr_chunks

            while proc.poll() is None:
                try:
                    if timeout is None:
                        rlist, wlist, xlist = select.select(list(fds.keys()), [], [])
                    else:
                        rlist, wlist, xlist = select.select(list(fds.keys()), [], [], timeout)
                except OSError as err:
                    if err.args[0] == errno.EINTR:
                        # interrupted by signal handler; try again
                        continue
                    raise

                if len(rlist) == 0:
                    # We got a timeout
                    is_timeout = True
                    break
                for f in rlist:
                    output = os.read(f, PIPE_BUF)
                    fds[f].append(output)
                    debug_log(output)
                if dots and time.time() - last_dot_time > 0.5:
                    if dots is True:
                        log.dot()
                    elif dots:
                        dots()
                    last_dot_time = time.time()
        finally:
            if posix and is_main_thread():
                # Restore signal handlers
                signal.signal(signal.SIGTSTP, signal.SIG_DFL)
                signal.signal(signal.SIGCONT, signal.SIG_DFL)

            if proc.returncode is None:
                # Timeout or another exceptional condition occurred, and
                # the program is still running.
                if posix:
                    # Terminate the whole process group
                    _killpg_safe(proc.pid, signal.SIGTERM)
                    for _ in range(10):
                        time.sleep(0.1)
                        if proc.poll() is not None:
                            break
                    else:
                        # Didn't terminate within 1 sec, so kill it
                        _killpg_safe(proc.pid, signal.SIGKILL)
                else:
                    proc.terminate()
                proc.wait()

        proc.stdout.flush()
        if not redirect_stderr:
            proc.stderr.flush()

        # Collect whatever is still left in the pipes
        stdout_chunks.append(proc.stdout.read())
        if not redirect_stderr:
            stderr_chunks.append(proc.stderr.read())

        proc.stdout.close()
        if not redirect_stderr:
            proc.stderr.close()

    # Flush and disconnect debug log, if any
    debug_log(None)

    def debug_log(c):
        return None

    stdout = b''.join(stdout_chunks)
    stderr = b''.join(stderr_chunks)

    stdout = stdout.decode('utf-8', 'replace')
    stderr = stderr.decode('utf-8', 'replace')

    # A timeout is reported through the special TIMEOUT_RETCODE value
    if is_timeout:
        retcode = TIMEOUT_RETCODE
    else:
        retcode = proc.returncode

    if valid_return_codes is not None and retcode not in valid_return_codes:
        header = f"Error running {' '.join(args)} (exit status {retcode})"
        if display_error:
            if log.is_debug_enabled():
                # Output was already printed
                log.error(header)
            else:
                log.error(get_content(header))
        raise ProcessError(args, retcode, stdout, stderr)

    if return_stderr:
        return (stdout, stderr, retcode)
    else:
        return stdout
[docs]
def _killpg_safe(pgid, signo):
    """
    Send signal `signo` to process group `pgid` like os.killpg, but
    ignore EPERM errors.

    OSX/BSD may raise EPERM on killpg if the process group already
    terminated. Any other OSError is propagated.
    """
    try:
        os.killpg(pgid, signo)
    except OSError as exc:
        if exc.errno != errno.EPERM:
            raise
[docs]
def is_main_thread():
    """
    Tell whether the calling thread is the interpreter's main thread.
    """
    current = threading.current_thread()
    return current == threading.main_thread()
[docs]
def write_json(path, data, api_version=None, compact=False):
    """
    Writes JSON to the given path, including indentation and sorting.

    The parent directory is created if it does not exist yet.

    Parameters
    ----------
    path : str
        File name to write
    data : object
        Data to serialize as JSON
    api_version : int, optional
        API version number. When given, a shallow copy of `data` is
        written with its ``version`` key set to this value.
    compact : bool, optional
        Whether to produce compact, non-human readable JSON.
        Disables sorting and indentation.
    """
    path = os.path.abspath(path)

    dirname = long_path(os.path.dirname(path))
    # exist_ok avoids failing when another process creates the directory
    # between an existence check and the makedirs call.
    os.makedirs(dirname, exist_ok=True)

    if api_version is not None:
        # Copy so that the caller's mapping is not modified
        data = dict(data)
        data['version'] = api_version

    with long_path_open(path, 'w', encoding='utf-8') as fd:
        if compact:
            json.dump(data, fd)
        else:
            json.dump(data, fd, indent=4, sort_keys=True)
[docs]
def load_json(path, api_version=None, js_comments=False):
    """
    Read and parse a JSON file.

    Parameters
    ----------
    path : str
        File name
    api_version : str or None
        API version identifier. When given, the file must contain a
        ``version`` entry equal to it; that entry is removed from the
        returned data.
    js_comments : bool, optional
        Whether to allow nonstandard javascript-style comments
        in the file (the content is then parsed with ``json5``).
        Note that this slows down the loading significantly.

    Returns
    -------
    data : object
        The parsed content.

    Raises
    ------
    UserError
        If the content is not valid JSON (``js_comments=False`` only), or
        if the stored version is missing, older or newer than `api_version`.
    """
    # Hide traceback from expected exceptions in pytest reports
    __tracebackhide__ = operator.methodcaller('errisinstance', UserError)

    path = os.path.abspath(path)
    with long_path_open(path, 'r', encoding='utf-8') as fd:
        content = fd.read()

    if js_comments:
        # strips comments out
        data = json5.loads(content)
    else:
        try:
            data = json.loads(content)
        except ValueError as err:
            raise UserError(f"Error parsing JSON in file '{path}': {err}")

    if api_version is None:
        return data

    if 'version' not in data:
        raise UserError(f"No version specified in {path}.")

    stored_version = data['version']
    if stored_version < api_version:
        raise UserError(
            f"{path} is stored in an old file format. Run `asv update` to update it."
        )
    if stored_version > api_version:
        raise UserError(
            f"{path} is stored in a format that is newer than "
            "what this version of asv understands. Update "
            "asv to use this file."
        )

    del data['version']
    return data
[docs]
def update_json(cls, path, api_version, compact=False):
    """
    Bring a JSON file up to the current file format version.

    Parameters
    ----------
    cls : object
        Object containing methods update_to_X which updates
        the given JSON tree from version X-1 to X. Versions without such
        a method leave the tree unchanged.
    path : str
        Path to JSON file
    api_version : int
        The current API version
    compact : bool, optional
        Passed on to `write_json` when the file is rewritten.

    Raises
    ------
    UserError
        If the file has no ``version`` entry or its version is newer than
        `api_version`.
    """
    # Hide traceback from expected exceptions in pytest reports
    __tracebackhide__ = operator.methodcaller('errisinstance', UserError)

    def _unchanged(tree):
        return tree

    d = load_json(path)
    if 'version' not in d:
        raise UserError(f"No version specified in {path}.")

    stored_version = d['version']
    if stored_version < api_version:
        # Apply the per-version updaters one after another, then rewrite.
        for target in range(stored_version + 1, api_version + 1):
            updater = getattr(cls, f'update_to_{target}', _unchanged)
            d = updater(d)
        write_json(path, d, api_version, compact=compact)
    elif stored_version > api_version:
        raise UserError(
            f"{path} is stored in a format that is newer than "
            "what this version of asv understands. "
            "Upgrade asv in order to use or add to "
            "these results."
        )
[docs]
def iter_chunks(s, n):
    """
    Yield the elements of `s` as lists of `n` items each.

    The last list may be shorter when the number of elements is not a
    multiple of `n`.
    """
    pending = []
    for element in s:
        pending.append(element)
        if len(pending) == n:
            yield pending
            pending = []
    if pending:
        yield pending
[docs]
def pick_n(items, n):
    """
    Pick at most n items, attempting to get equal index spacing.

    Raises ValueError unless ``n > 0``.
    """
    if not (n > 0):
        raise ValueError("Invalid number of items to pick")

    total = len(items)
    step = max(float(total) / n, 1)

    picked = []
    position = 0
    while len(picked) < n and int(position) < total:
        picked.append(items[int(position)])
        position += step
    return picked
[docs]
def get_multiprocessing(parallel):
    """
    Resolve the requested degree of parallelism.

    Returns a ``(parallel, module)`` pair. For ``parallel == 1`` the module
    is ``None``. Otherwise the multiprocessing module is imported and
    returned, and a non-positive `parallel` is replaced by the CPU count.
    """
    if parallel == 1:
        return parallel, None

    import multiprocessing

    if parallel <= 0:
        parallel = multiprocessing.cpu_count()
    return parallel, multiprocessing
[docs]
def iter_subclasses(cls):
    """
    Yield every direct and indirect subclass of a class, depth first.
    """
    for child in cls.__subclasses__():
        yield child
        yield from iter_subclasses(child)
[docs]
def hash_equal(a, b):
    """
    Compare two commit hashes, ignoring case, over the length of the
    shorter one.

    Returns `True` if a and b represent the same commit hash.
    """
    common = min(len(a), len(b))
    lowered_a, lowered_b = a.lower(), b.lower()
    return lowered_a[:common] == lowered_b[:common]
[docs]
def get_cpu_info():
    """
    Return a human-friendly description of this machine's CPU.

    On Linux the first ``model name`` entry of ``/proc/cpuinfo`` is used,
    on macOS the ``machdep.cpu.brand_string`` sysctl, and on Windows a WMI
    query through ``win32com``.

    Returns '' if it can't be obtained.
    """
    platform = sys.platform

    if platform.startswith('linux'):
        with open("/proc/cpuinfo", "rb") as fd:
            lines = fd.readlines()
        for line in lines:
            key, colon, val = line.partition(b':')
            if colon and key.strip() == b'model name':
                return val.strip().decode('ascii')
    elif platform.startswith('darwin'):
        sysctl = which('sysctl')
        brand = check_output([sysctl, '-n', 'machdep.cpu.brand_string'])
        return brand.strip()
    elif platform.startswith('win'):
        try:
            from win32com.client import GetObject

            cimv = GetObject(r"winmgmts:root\cimv2")
            return cimv.ExecQuery("Select Name from Win32_Processor")[0].name
        except Exception:
            # Best effort only; fall through to the empty result
            pass

    return ''
[docs]
def get_memsize():
    """
    Return the amount of physical memory in this machine.

    On Linux the ``MemTotal`` entry of ``/proc/meminfo`` is used (multiplied
    by 1024 when it is given in ``kB``), on macOS the ``hw.memsize`` sysctl.

    Returns '' if it can't be obtained.
    """
    platform = sys.platform

    if platform.startswith('linux'):
        with open("/proc/meminfo", "rb") as fd:
            lines = fd.readlines()
        for line in lines:
            key, colon, val = line.partition(b':')
            if not colon or key.strip() != b'MemTotal':
                continue
            val = val.strip()
            units = 1024 if val.endswith(b' kB') else 1
            return int(val.split()[0]) * units
    elif platform.startswith('darwin'):
        sysctl = which('sysctl')
        reported = check_output([sysctl, '-n', 'hw.memsize'])
        return int(reported.strip())

    return ''
[docs]
def format_text_table(rows, num_headers=0, top_header_span_start=0, top_header_text=None):
    """
    Render rows of items as a reStructuredText-style simple table.

    Every item is converted to text (newlines become spaces), short rows
    are padded with empty cells, and each cell is centered in a column two
    characters wider than its longest entry. The table is framed by ``=``
    rules; header rows are separated from the content by a ``-`` rule.

    Parameters
    ----------
    rows : list of list
        The table cells, one list per row.
    num_headers : int, optional
        Number of leading rows to treat as headers.
    top_header_span_start : int, optional
        Index of the first column spanned by the top header.
    top_header_text : str, optional
        Text of an extra header line spanning the columns from
        `top_header_span_start` on. When given, the header rows are followed
        by a ``=`` rule instead of a ``-`` rule.

    Returns
    -------
    table : str
    """
    # Format content
    cell_rows = [[f"{item}".replace("\n", " ") for item in row] for row in rows]

    # Ensure same number of items on all rows
    num_items = max(len(row) for row in cell_rows)
    cell_rows = [row + [''] * (num_items - len(row)) for row in cell_rows]

    # Determine widths
    col_widths = [max(len(cell) for cell in column) + 2 for column in zip(*cell_rows)]

    # Pad content and join the cells of each row
    lines = [
        " ".join(cell.center(width) for width, cell in zip(col_widths, row))
        for row in cell_rows
    ]
    headers, content = lines[:num_headers], lines[num_headers:]

    separator = " ".join("-" * w for w in col_widths)
    frame = separator.replace("-", "=")

    # Generate result
    middle = []
    if top_header_text is not None:
        left_span = "-".join("-" * w for w in col_widths[:top_header_span_start])
        right_span = "-".join("-" * w for w in col_widths[top_header_span_start:])
        if left_span and right_span:
            middle.append(
                "--" + " " * (len(left_span) - 1) + top_header_text.center(len(right_span))
            )
            middle.append(" ".join([left_span, right_span]))
        else:
            middle.append(top_header_text.center(len(separator)))
            middle.append("-".join([left_span, right_span]))
        middle.extend(headers)
        middle.append(frame)
    elif headers:
        middle.extend(headers)
        middle.append(separator)

    return "\n".join([frame, *middle, *content, frame])
[docs]
def _datetime_to_timestamp(dt, divisor):
delta = dt - datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
microseconds = (delta.days * 86400 + delta.seconds) * 10**6 + delta.microseconds
value, remainder = divmod(microseconds, divisor)
if remainder >= divisor // 2:
value += 1
return value
[docs]
def datetime_to_timestamp(dt):
    """
    Convert a Python datetime object to a UNIX timestamp, i.e. whole
    seconds since the epoch.
    """
    microseconds_per_second = 1_000_000
    return _datetime_to_timestamp(dt, microseconds_per_second)
[docs]
def datetime_to_js_timestamp(dt):
    """
    Convert a Python datetime object to a JavaScript timestamp, i.e. whole
    milliseconds since the epoch.
    """
    microseconds_per_millisecond = 1_000
    return _datetime_to_timestamp(dt, microseconds_per_millisecond)
[docs]
def js_timestamp_to_datetime(ts):
    """
    Convert a JavaScript timestamp (milliseconds since the epoch) to a
    Python datetime object.

    The result comes from `datetime.fromtimestamp` without a timezone
    argument, i.e. it is a naive datetime in local time.
    """
    seconds = ts / 1000
    return datetime.datetime.fromtimestamp(seconds)
[docs]
def is_nan(x):
    """
    Tell whether x is a floating-point NaN.

    Returns `True` only for float instances that compare unequal to
    themselves; any non-float value gives `False`.
    """
    return isinstance(x, float) and x != x
[docs]
def is_na(value):
    """
    Tell whether a value denotes missing data, i.e. is None or NaN.
    """
    if value is None:
        return True
    return is_nan(value)
[docs]
def mean_na(values):
    """
    Arithmetic mean of the values that are neither None nor NaN (those
    stand for missing data).

    Returns None when no value is left.
    """
    present = [x for x in values if not is_na(x)]
    if not present:
        return None
    return sum(present) / len(present)
[docs]
def geom_mean_na(values):
    """
    Geometric mean of the values that are neither None nor NaN (those
    stand for missing data).

    The mean is computed from the absolute values; the result is negated
    when the sum of the values is negative. Returns None when no value is
    left.
    """
    present = [x for x in values if not is_na(x)]
    if not present:
        return None

    exponent = 1 / len(present)
    magnitude = 1.0
    total = 0
    for x in present:
        # Take the root of each factor rather than of the full product
        magnitude *= abs(x) ** exponent
        total += x
    return magnitude if total >= 0 else -magnitude
# Helpers for file system access. On Windows, paths are rewritten to the
# ``\\?\`` extended-length form before use; elsewhere the plain standard
# library functions are used.
if not WIN:
    # `long_path_open` must exist on every platform: write_json and
    # load_json call it unconditionally.
    long_path_open = open
    long_path_rmtree = shutil.rmtree

    def long_path(path):
        """Return `path` unchanged."""
        return path

else:

    def long_path(path):
        """
        Return the absolute form of `path` prefixed with ``\\\\?\\``.
        Paths that already start with two backslashes are returned as is.
        """
        if path.startswith("\\\\"):
            return path
        return "\\\\?\\" + os.path.abspath(path)

    def _remove_readonly(func, path, exc_info):
        """Try harder to remove files on Windows"""
        if isinstance(exc_info[1], OSError) and exc_info[1].errno == errno.EACCES:
            # Clear read-only flag and try again
            try:
                os.chmod(path, stat.S_IWRITE | stat.S_IREAD)
                func(path)
                return
            except OSError:
                pass
        # Reraise original error
        raise exc_info[1]

    def long_path_open(filename, *a, **kw):
        """Like the builtin `open`, applied to the `long_path` form of `filename`."""
        return open(long_path(filename), *a, **kw)

    def long_path_rmtree(path, ignore_errors=False):
        """
        Like `shutil.rmtree`, applied to the `long_path` form of `path`.
        Unless errors are ignored, removal of read-only entries is retried
        after clearing the read-only flag.
        """
        if ignore_errors:
            onerror = None
        else:
            onerror = _remove_readonly
        shutil.rmtree(long_path(path), ignore_errors=ignore_errors, onerror=onerror)
[docs]
def sanitize_filename(filename):
    """
    Replace characters to make a string safe to use in file names.

    The characters ``< > : " / \\ ^ | ? *`` and the control characters
    0x00-0x1f are each replaced by an underscore, and an underscore is
    appended to names reserved on NTFS (``CON``, ``NUL``, ``COM1``, ...).

    This is not a 1-to-1 mapping.

    The implementation needs to match www/asv.js:escape_graph_parameter

    Parameters
    ----------
    filename : str or bytes
        The name to sanitize. Bytes are decoded with the file system
        encoding first.

    Returns
    -------
    filename : str
    """
    if not isinstance(filename, str):
        filename = filename.decode(sys.getfilesystemencoding())

    # ntfs & ext3
    # Raw string: `\\` is a literal backslash inside the character class.
    # In a non-raw literal the same text yields `\^`, an escaped caret,
    # and backslashes are left in place.
    filename = re.sub(r'[<>:"/\\^|?*\x00-\x1f]', '_', filename)

    # ntfs
    forbidden = (
        "CON",
        "PRN",
        "AUX",
        "NUL",
        "COM1",
        "COM2",
        "COM3",
        "COM4",
        "COM5",
        "COM6",
        "COM7",
        "COM8",
        "COM9",
        "LPT1",
        "LPT2",
        "LPT3",
        "LPT4",
        "LPT5",
        "LPT6",
        "LPT7",
        "LPT8",
        "LPT9",
    )
    if filename.upper() in forbidden:
        filename = filename + "_"

    return filename
[docs]
def namedtuple_with_doc(name, slots, doc):
    """
    Create a `collections.namedtuple` class called `name` with the fields
    `slots` and use `doc` as its docstring.
    """
    tuple_cls = collections.namedtuple(name, slots)
    tuple_cls.__doc__ = doc
    return tuple_cls
[docs]
def recvall(sock, size):
    """
    Read exactly `size` bytes from a socket connection.

    `sock.recv` is called repeatedly until enough data has arrived. A
    RuntimeError is raised if a call returns no data before that.
    """
    data = b""
    missing = size
    while missing > 0:
        chunk = sock.recv(missing)
        data += chunk
        if not chunk:
            raise RuntimeError(
                f"did not receive data from socket (size {size}, got only {data!r})"
            )
        missing = size - len(data)
    return data
[docs]
def interpolate_command(command, variables):
    """
    Split a command line into its parts and substitute variables into it.

    The command is parsed as in posix-style shell (by shlex), and every
    part is run through `str.format` with `variables`. Leading parts of
    the following forms are consumed as declarations rather than command
    arguments:

    - ``ENVVAR=value <command>``: parsed as declaring an environment variable
      named 'ENVVAR'.
    - ``return-code=value <command>``: parsed as declaring valid return codes
      (a comma-separated list of integers, or ``any``).
    - ``in-dir=value <command>``: parsed as declaring working directory for command.

    Parameters
    ----------
    command : str
        Command to execute, posix shell style.
    variables : dict
        Interpolation variables.

    Returns
    -------
    command : list of str
        Command arguments.
    env : dict
        Environment variables declared in the command.
    return_codes : {set, int, None}
        Valid return codes. ``{0}`` unless declared; None for ``any``.
    cwd : {str, None}
        Current working directory for the command, if any.

    Raises
    ------
    UserError
        If a variable is not available, a declaration is given more than
        once, or a return-code declaration is malformed.
    """
    parts = shlex.split(command)

    try:
        result = [c.format(**variables) for c in parts]
    except KeyError as exc:
        raise UserError(
            f"Configuration error: {{{exc.args[0]}}} not available "
            f"when substituting into command {command!r} "
            f"Available: {variables!r}"
        )

    env = {}
    return_codes_set = False
    return_codes = {0}
    cwd = None

    # Consume declarations from the front until the actual command starts.
    while result:
        head = result[0]

        env_match = re.match('^([A-Za-z_][A-Za-z0-9_]*)=(.*)$', head)
        if env_match:
            env[env_match.group(1)] = env_match.group(2)
        elif head.startswith('return-code='):
            if return_codes_set:
                raise UserError(
                    "Configuration error: multiple return-code specifications "
                    f"in command {command!r} "
                )
            if head == 'return-code=any':
                return_codes = None
            else:
                parsed_codes = None
                codes_match = re.match('^return-code=([0-9,]+)$', head)
                if codes_match:
                    try:
                        parsed_codes = {int(x) for x in codes_match.group(1).split(",")}
                    except ValueError:
                        # e.g. an empty entry between two commas
                        parsed_codes = None
                if parsed_codes is None:
                    raise UserError(
                        "Configuration error: invalid return-code specification "
                        f"{result[0]!r} when substituting into command {command!r} "
                    )
                return_codes = parsed_codes
            return_codes_set = True
        elif head.startswith('in-dir='):
            if cwd is not None:
                raise UserError(
                    f"Configuration error: multiple in-dir specifications in command {command!r} "
                )
            cwd = head[7:]
        else:
            break

        del result[0]

    return result, env, return_codes, cwd
[docs]
def truncate_float_list(item, digits=5):
    """
    Round floating-point numbers to the given number of significant
    digits, for a shorter base-10 representation.

    Lists are processed recursively; anything that is neither a float nor
    a list is returned unchanged.
    """
    if isinstance(item, list):
        return [truncate_float_list(element, digits) for element in item]
    if isinstance(item, float):
        # Round-trip through scientific notation with digits-1 decimals
        return float(f'{item:.{digits - 1}e}')
    return item
[docs]
def _init_global_locks(lock_dict, env):
    """
    Initialize global locks in a new multiprocessing process.

    The given locks are merged into the module-level lock table, and the
    given environment is merged into `os.environ`, so that the base
    environment is inherited even if using a forkserver.
    """
    os.environ.update(env)
    _global_locks.update(lock_dict)
[docs]
def new_multiprocessing_lock(name):
    """
    Create a new global multiprocessing lock and register it under `name`,
    replacing any lock previously stored under that name.
    """
    lock = multiprocessing.Lock()
    _global_locks[name] = lock
[docs]
def get_multiprocessing_lock(name):
    """
    Look up an existing global multiprocessing lock by name.

    Raises KeyError if no lock has been registered under `name`.
    """
    lock = _global_locks[name]
    return lock
[docs]
def get_multiprocessing_pool(parallel=None):
    """
    Create a multiprocessing.Pool, managing global locks properly.

    Each worker is initialized with the current global locks and a copy of
    the current environment variables.
    """
    inherited_env = os.environ.copy()
    worker_args = (_global_locks, inherited_env)
    return multiprocessing.Pool(parallel, initializer=_init_global_locks, initargs=worker_args)
# Prefer the standard library implementation; the fallback below is only
# defined when `shlex.quote` cannot be imported.
try:
    from shlex import quote as shlex_quote
except ImportError:
    # Finds the first character that is not safe to leave unquoted
    _find_unsafe = re.compile(r'[^\w@%+=:,./-]').search

    def shlex_quote(s):
        """Return a shell-escaped version of the string *s*."""
        if not s:
            return "''"
        if _find_unsafe(s) is not None:
            # use single quotes, and put single quotes into double quotes
            # the string $'b is then quoted as '$'"'"'b'
            return "'" + s.replace("'", "'\"'\"'") + "'"
        return s
[docs]
def git_default_branch():
    """
    Return the value of git's ``init.defaultBranch`` setting.

    The setting is queried without a scope option first, then with
    ``--global`` and finally with ``--system``; the first query that
    succeeds wins. If all of them fail, ``master`` is returned.
    """
    # Local name gets precedence, then global, then system
    for scope_args in ([], ['--global'], ['--system']):
        command = [which('git'), 'config', *scope_args, 'init.defaultBranch']
        try:
            return check_output(command, display_error=False).strip()
        except ProcessError:
            continue

    # Default to master when global and system are not set
    return 'master'
[docs]
def search_channels(cli_path, pkg, version):
    """
    Ask a package manager whether an exact package version is available.

    Runs ``<cli_path> search <pkg>==<version>`` and inspects its standard
    output.

    Parameters
    ----------
    cli_path : str
        Path of the command line tool to run.
    pkg : str
        Package name.
    version : str
        Package version.

    Returns
    -------
    found : bool
        False if the tool could not be run or reported
        ``No match found for: <pkg>==<version>.``; True otherwise.
    """
    query = f"{pkg}=={version}"
    try:
        result = subprocess.run(
            [cli_path, "search", query], capture_output=True, text=True, check=False
        )
    except (OSError, subprocess.SubprocessError) as e:
        # With check=False a non-zero exit status never raises, so the only
        # failures to expect are problems launching the tool (e.g. a missing
        # executable), which surface as OSError.
        print(f"Error searching for {pkg} {version}, got:\n {e}", file=sys.stderr)
        return False
    return f"No match found for: {query}." not in result.stdout
[docs]
class ParsedPipDeclaration:
    """
    Parsed form of a pip requirement declaration.

    Parameters
    ----------
    declaration : str
        The declaration, e.g. ``numpy>=1.20`` or ``-e ./mypkg``.

    Raises
    ------
    ValueError
        If neither a package name nor a path can be found in the declaration.
    """

    def __init__(self, declaration):
        # Package name, if one was found
        self.pkgname = None
        # Version specification such as ">=1.20,<2", if any
        self.specification = None
        # Command line flags found in the declaration
        self.flags = []
        # Whether the "-e" flag was given
        self.is_editable = False
        # git URL or local path, if one was found
        self.path = None
        self._parse_declaration(declaration)
        if not self.pkgname and not self.path:
            raise ValueError(
                "Either a valid package name or a path must be present in the declaration."
            )

    def _parse_declaration(self, declaration):
        """
        Fill in the attributes from `declaration`.

        Recognized parts are removed from the text one after another:
        flags with quoted values, a git URL, a local path, plain flags and
        finally the package name with its version specification.
        """
        # Match flags with values
        flag_with_value_pattern = (
            r'(--[\w-]+='  # Match the flag name
            r'\"[^\"]+\")'  # Match the value in double quotes
        )
        flag_values = re.findall(flag_with_value_pattern, declaration)
        for flag_value in flag_values:
            self.flags.append(flag_value)
            declaration = declaration.replace(flag_value, '', 1)

        # Match git URLs
        git_url_pattern = (
            r'(git\+https:\/\/[a-zA-Z0-9-_\/.]+)'  # match the git URL
            r'(?:@([a-zA-Z0-9-_\/.]+))?'  # optional branch or tag
            r'(?:#egg=([a-zA-Z0-9-_]+))?'  # optional egg fragment
        )
        git_url_match = re.search(git_url_pattern, declaration)

        # If there's a git URL match, remove it from the declaration
        if git_url_match:
            self.path = git_url_match.group(1)
            branch_or_tag = git_url_match.group(2)
            if branch_or_tag:
                self.path += f"@{branch_or_tag}"
            if git_url_match.group(3):
                self.pkgname = git_url_match.group(3)
            declaration = declaration.replace(git_url_match.group(0), '', 1)

        # Match local paths
        local_pattern = (
            r'(\.\/[a-zA-Z0-9-_]+\/?'  # Relative path starting with ./
            r'|\.\.\/[a-zA-Z0-9-_]+\/?'  # Relative path starting with ../
            r'|\/\w+\/?)'  # Absolute path
        )
        local_match = re.search(local_pattern, declaration)

        # If there's a local path match, remove it from the declaration
        if local_match:
            self.path = local_match.group(1)
            declaration = declaration.replace(local_match.group(0), '', 1)

        # Match flags
        flags_pattern = (
            r'(?:^|\s)'  # Match start or whitespace
            r'(-[a-zA-Z]|'  # Single-letter flags
            r'--\w+(?:-\w+)*)'  # Double-dash flags
        )
        flags = re.findall(flags_pattern, declaration)
        if flags:
            self.flags.extend(flags)
            if "-e" in self.flags:
                self.is_editable = True
            # Remove matched flags from declaration
            for flag in self.flags:
                declaration = declaration.replace(flag, '', 1)

        # Match package details
        pkg_pattern = (
            r'(?P<name>[a-zA-Z0-9-_]+)'  # Name
            r'('  # Start group for version specification(s)
            r'((?P<specifier>[<>!=~]{1,2})'  # Version specifier
            r'(?P<version>[0-9.a-zA-Z_-]+))'  # Version
            r'((?P<multi_spec>,[<>!=~]{1,2}[0-9.a-zA-Z_-]+)*)?'  # Multiple version specifications
            r')?'  # End optional group for version specification(s)
        )
        pkg_match = re.search(pkg_pattern, declaration)

        # Populate attributes based on package details matches
        if pkg_match:
            self.pkgname = pkg_match.group("name")
            specifier = pkg_match.group("specifier")
            version = pkg_match.group("version")
            multi_spec = pkg_match.group("multi_spec")

            # If a specifier is present, prioritize it and consume version(s) greedily
            if specifier:
                self.specification = f"{specifier}{version}{multi_spec or ''}"
            else:
                # A bare version after the name is treated as an exact pin
                declaration = declaration.replace(pkg_match.group("name"), '', 1)
                version_match = re.search(r'(?P<version>\d+(\.\d+)*([a-zA-Z0-9]+)?)', declaration)
                if version_match:
                    self.specification = f"=={version_match.group(0)}"
[docs]
def construct_pip_call(pip_caller, parsed_declaration: ParsedPipDeclaration):
    """
    Build a deferred ``pip install`` call for a parsed declaration.

    The argument list starts with ``install -v --upgrade``, followed by the
    declaration's flags and then its path or, if it has no path, its
    package name with the version specification appended.

    Returns
    -------
    call : functools.partial
        `pip_caller` bound to the argument list as its first argument.
    """
    pargs = ['install', '-v', '--upgrade']
    if parsed_declaration.flags:
        pargs.extend(parsed_declaration.flags)

    # A path takes precedence over the package name
    if parsed_declaration.path:
        pargs.append(parsed_declaration.path)
    elif parsed_declaration.pkgname:
        requirement = parsed_declaration.pkgname
        if parsed_declaration.specification:
            requirement = f"{parsed_declaration.pkgname}{parsed_declaration.specification}"
        pargs.append(requirement)

    return functools.partial(pip_caller, pargs)
# True when the interpreter exposes `sys.pypy_version_info`, i.e. on PyPy.
ON_PYPY = hasattr(sys, 'pypy_version_info')
[docs]
def get_matching_environment(environments, result=None):
    """
    Return the first environment whose Python version is that of the
    running interpreter.

    If `result` is given, only environments whose ``name`` equals
    ``result.env_name`` are considered. Returns None when nothing matches.
    """
    for env in environments:
        if result is not None and not (result.env_name == env.name):
            continue
        if env_py_is_sys_version(env.python):
            return env
    return None
[docs]
def replace_cpython_version(arg, new_version):
    """
    Replace a ``python`` requirement by ``python=<new_version>``.

    `arg` counts as a python requirement if it is the word ``python``,
    optionally followed by a version constraint (``python=3.8``,
    ``python>=3.8``, ...). Other package names that merely start with
    ``python``, such as ``python-dateutil``, ``python.app`` or
    ``python_abi``, are returned unchanged, as is any other argument.

    Parameters
    ----------
    arg : str
        A package requirement.
    new_version : str
        The version to pin python to.

    Returns
    -------
    arg : str
    """
    # Characters that may continue a package name must not follow "python".
    if re.match(r"^python(?![\w.-])", arg):
        return f"python={new_version}"
    return arg
[docs]
def env_py_is_sys_version(env_python):
    """
    Tell whether the CPython version extracted from `env_python` equals the
    ``major.minor`` version of the running interpreter.
    """
    running_version = "{}.{}".format(*sys.version_info[:2])
    return extract_cpython_version(env_python) == running_version