import os
import sys
import tarfile
import zipfile
from datetime import datetime
from io import BytesIO, TextIOWrapper
from pathlib import Path
from typing import Dict, Iterable, Optional, Union
import yaml
class Archiver:
    """Uniform read/write interface over :mod:`tarfile` and :mod:`zipfile`.

    The backend is chosen from the extension of the archive path:

    * ``.tar``                -> uncompressed tar
    * ``.tgz`` / ``.tar.gz``  -> gzip-compressed tar
    * ``.tbz2`` / ``.tar.bz2`` -> bzip2-compressed tar
    * ``.txz`` / ``.tar.xz``  -> xz-compressed tar
    * ``.zip``                -> zip

    Instances are context managers; leaving the ``with`` block closes the
    underlying archive object (``self.fopen``).
    """

    def __init__(self, file, mode="r"):
        """Open *file* with the backend matching its extension.

        Args:
            file: Archive path (str or Path).
            mode: Mode passed to the backend. For compressed tar archives a
                plain ``"w"`` is expanded to ``"w:gz"``/``"w:bz2"``/``"w:xz"``;
                for reading, ``tarfile.open`` detects compression itself.

        Raises:
            ValueError: If the archive format cannot be inferred from *file*.
        """
        path = Path(file)
        # Only the last two suffixes matter, so names that contain extra dots
        # (e.g. "model_v1.0.tar.gz" -> [".0", ".tar", ".gz"]) are recognized.
        last_two = path.suffixes[-2:]
        if path.suffix == ".tar":
            self.type = "tar"
        elif path.suffix == ".tgz" or last_two == [".tar", ".gz"]:
            self.type = "tar"
            if mode == "w":
                mode = "w:gz"
        elif path.suffix == ".tbz2" or last_two == [".tar", ".bz2"]:
            self.type = "tar"
            if mode == "w":
                mode = "w:bz2"
        elif path.suffix == ".txz" or last_two == [".tar", ".xz"]:
            self.type = "tar"
            if mode == "w":
                mode = "w:xz"
        elif path.suffix == ".zip":
            self.type = "zip"
        else:
            raise ValueError(f"Cannot detect archive format: type={file}")

        if self.type == "tar":
            self.fopen = tarfile.open(file, mode=mode)
        elif self.type == "zip":
            self.fopen = zipfile.ZipFile(file, mode=mode)
        else:
            # Unreachable with the detection above; kept as a safeguard.
            raise ValueError(f"Not supported: type={self.type}")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.fopen.close()

    def close(self):
        """Close the underlying archive."""
        self.fopen.close()

    def __iter__(self):
        """Iterate over member infos (TarInfo for tar, ZipInfo for zip)."""
        if self.type == "tar":
            return iter(self.fopen)
        elif self.type == "zip":
            return iter(self.fopen.infolist())
        else:
            raise ValueError(f"Not supported: type={self.type}")

    def add(self, filename, arcname=None, recursive: bool = True):
        """Add a file or a directory tree from disk to the archive.

        Args:
            filename: Path on disk to add.
            arcname: Name to store it under (defaults to *filename*). When a
                directory is walked recursively, each file below it is stored
                as ``arcname / <path relative to filename>``.
            recursive: If True and *filename* is a directory, add every file
                below it (directories themselves are not stored as entries).
                If False, only the entry for *filename* itself is added, for
                both tar and zip.
        """
        if arcname is not None:
            print(f"adding: {arcname}")
        else:
            print(f"adding: {filename}")

        if recursive and Path(filename).is_dir():
            for f in Path(filename).glob("**/*"):
                if f.is_dir():
                    continue
                if arcname is not None:
                    # Re-root the file under arcname. Joining the full glob
                    # path would duplicate the source prefix (or, for an
                    # absolute path, discard arcname entirely).
                    _arcname = Path(arcname) / f.relative_to(filename)
                else:
                    _arcname = None
                self.add(f, _arcname)
            return

        if self.type == "tar":
            # Directory walking is handled above; a directory reaching this
            # point must be stored as a single entry, like zipfile.write does.
            return self.fopen.add(filename, arcname, recursive=False)
        elif self.type == "zip":
            return self.fopen.write(filename, arcname)
        else:
            raise ValueError(f"Not supported: type={self.type}")

    def addfile(self, info, fileobj):
        """Add a member described by *info* whose content is read from *fileobj*."""
        print(f"adding: {self.get_name_from_info(info)}")
        if self.type == "tar":
            return self.fopen.addfile(info, fileobj)
        elif self.type == "zip":
            return self.fopen.writestr(info, fileobj.read())
        else:
            raise ValueError(f"Not supported: type={self.type}")

    def extract(self, info, path=None):
        """Extract the member *info* below directory *path* (default: CWD).

        Returns whatever the backend's ``extract`` returns.

        NOTE(review): ``tarfile.extract`` does not sanitize member names
        (``..`` components, absolute paths, links) unless an extraction
        filter is in effect, so only use this on trusted archives.
        """
        if self.type == "tar":
            return self.fopen.extract(info, path)
        elif self.type == "zip":
            return self.fopen.extract(info, path)
        else:
            raise ValueError(f"Not supported: type={self.type}")

    def extractfile(self, info, mode="r"):
        """Open the member *info* for reading.

        Args:
            info: TarInfo/ZipInfo obtained by iterating this archive.
            mode: ``"r"`` for a UTF-8 text stream, ``"rb"`` for bytes.

        Returns:
            A readable stream. For tar members that are not regular files
            (e.g. directories) ``tarfile`` returns None, which is passed on.

        Raises:
            ValueError: If *mode* is neither ``"r"`` nor ``"rb"``.
        """
        if mode not in ("r", "rb"):
            raise ValueError(f"Invalid mode: {mode}")
        if self.type == "tar":
            f = self.fopen.extractfile(info)
        elif self.type == "zip":
            # ZipFile.open("r") always yields a binary stream
            f = self.fopen.open(info, "r")
        else:
            raise ValueError(f"Not supported: type={self.type}")
        if mode == "r" and f is not None:
            return TextIOWrapper(f, encoding="utf-8")
        return f

    def generate_info(self, name, size) -> Union[tarfile.TarInfo, zipfile.ZipInfo]:
        """Create a member info for an in-memory file of *size* bytes.

        The info is stamped with the current time; for tar on POSIX it also
        carries the current process uid/gid. The tar mode is left at the
        TarInfo default.
        """
        if self.type == "tar":
            tarinfo = tarfile.TarInfo(str(name))
            if os.name == "posix":
                tarinfo.gid = os.getgid()
                tarinfo.uid = os.getuid()
            tarinfo.mtime = datetime.now().timestamp()
            tarinfo.size = size
            # Keep mode as default
            return tarinfo
        elif self.type == "zip":
            zipinfo = zipfile.ZipInfo(str(name), datetime.now().timetuple()[:6])
            zipinfo.file_size = size
            return zipinfo
        else:
            raise ValueError(f"Not supported: type={self.type}")

    def get_name_from_info(self, info):
        """Return the member name stored in a TarInfo/ZipInfo."""
        if self.type == "tar":
            assert isinstance(info, tarfile.TarInfo), type(info)
            return info.name
        elif self.type == "zip":
            assert isinstance(info, zipfile.ZipInfo), type(info)
            return info.filename
        else:
            raise ValueError(f"Not supported: type={self.type}")
def find_path_and_change_it_recursive(value, src: str, tgt: str):
    """Return a copy of *value* with path strings equal to *src* replaced by *tgt*.

    Dicts are rebuilt with the same keys, and lists and tuples are rebuilt as
    lists. Both are processed recursively. A string matches when
    ``Path(string) == Path(src)``, so equivalent spellings such as
    ``"a/./b"`` and ``"a/b"`` both match. Any other value is returned unchanged.
    """

    def convert(node):
        return find_path_and_change_it_recursive(node, src, tgt)

    if isinstance(value, dict):
        return {key: convert(node) for key, node in value.items()}
    if isinstance(value, (list, tuple)):
        return list(map(convert, value))
    if isinstance(value, str) and Path(value) == Path(src):
        return tgt
    return value
def get_dict_from_cache(
    meta: Union[Path, str],
    outpath: Optional[Union[Path, str]] = None,
) -> Optional[Dict[str, str]]:
    """Return the key -> path mapping of a previous extraction, if complete.

    Reads the ``yaml_files`` and ``files`` tables of an already extracted
    ``meta.yaml``. It then maps every key to ``str(outpath / relative_path)``.

    Args:
        meta: Path of the extracted meta.yaml.
        outpath: Root directory the archive was extracted into. If omitted,
            it defaults to ``meta.parent.parent``. That default is the legacy
            assumption that meta.yaml sits one directory below the root. Pass
            *outpath* explicitly when meta.yaml lives elsewhere; for example,
            ``pack`` writes meta.yaml at the archive root, so its extracted
            copy is ``<outpath>/meta.yaml``.

    Returns:
        The mapping, or None if *meta* does not exist or any listed file is
        missing below the root.

    Raises:
        AssertionError: If meta.yaml or either table is not a dict.
        KeyError: If either table is missing from meta.yaml.
    """
    meta = Path(meta)
    root = meta.parent.parent if outpath is None else Path(outpath)
    if not meta.exists():
        return None

    with meta.open("r", encoding="utf-8") as f:
        d = yaml.safe_load(f)
    assert isinstance(d, dict), type(d)
    yaml_files = d["yaml_files"]
    files = d["files"]
    assert isinstance(yaml_files, dict), type(yaml_files)
    assert isinstance(files, dict), type(files)

    retval = {}
    for key, value in list(yaml_files.items()) + list(files.items()):
        path = root / value
        if not path.exists():
            return None
        retval[key] = str(path)
    return retval
def _parse_meta(stream):
    """Parse meta.yaml content and return its ``(yaml_files, files)`` tables.

    Raises:
        AssertionError: If the document or either table is not a dict.
        KeyError: If either table is missing.
    """
    d = yaml.safe_load(stream)
    assert isinstance(d, dict), type(d)
    yaml_files = d["yaml_files"]
    files = d["files"]
    assert isinstance(yaml_files, dict), type(yaml_files)
    assert isinstance(files, dict), type(files)
    return yaml_files, files


def _map_to_outpath(yaml_files: dict, files: dict, outpath: Path) -> Dict[str, str]:
    """Map every key of both tables to ``str(outpath / relative_path)``."""
    return {
        key: str(outpath / value)
        for key, value in list(yaml_files.items()) + list(files.items())
    }


def _cached_unpack_result(meta: Path, outpath: Path) -> Optional[Dict[str, str]]:
    """Return the mapping of an earlier extraction into *outpath*, if complete.

    *meta* is the previously extracted meta.yaml. Listed paths are resolved
    against *outpath*, where every member was extracted, rather than against
    meta's own location. This works wherever meta.yaml sits in the archive.
    Returns None if *meta* or any listed file is missing.
    """
    if not meta.exists():
        return None
    with meta.open("r", encoding="utf-8") as f:
        yaml_files, files = _parse_meta(f)
    entries = list(yaml_files.items()) + list(files.items())
    if not all((outpath / value).exists() for _, value in entries):
        return None
    return _map_to_outpath(yaml_files, files, outpath)


def unpack(
    input_archive: Union[Path, str],
    outpath: Union[Path, str],
    use_cache: bool = True,
) -> Dict[str, str]:
    """Extract an archive created by ``pack`` and return its files as a dict.

    The archive must contain a member named ``meta.yaml`` whose
    ``yaml_files`` and ``files`` tables map keys to member paths. Every
    member is extracted below *outpath*. Each YAML file listed in
    ``yaml_files`` is re-serialized so that any string equal to a member path
    is replaced by that member's extracted path.

    Args:
        input_archive: Archive path; its extension selects the format.
        outpath: Directory to extract into.
        use_cache: If True, check for a previous extraction into *outpath*
            whose meta.yaml and all listed files still exist. If found, return
            its mapping without extracting again.

    Returns:
        Mapping of each meta.yaml key to ``str(outpath / member_path)``.

    Raises:
        RuntimeError: If the archive contains no ``meta.yaml``.

    Example:
        An archive holds ``meta.yaml``, ``model.pth``, ``some1.file``, and
        ``some2.file``. Its meta.yaml has
        ``files: {asr_model_file: model.pth}``. Then::

            unpack("model.tar.gz", "out")
            # -> {'asr_model_file': 'out/model.pth'}
    """
    input_archive = Path(input_archive)
    outpath = Path(outpath)

    with Archiver(input_archive) as archive:
        for info in archive:
            meta_name = archive.get_name_from_info(info)
            if Path(meta_name).name == "meta.yaml":
                if use_cache:
                    # Resolve cached entries against outpath itself: that is
                    # where the previous run extracted every member.
                    retval = _cached_unpack_result(outpath / meta_name, outpath)
                    if retval is not None:
                        return retval
                with archive.extractfile(info) as f:
                    yaml_files, files = _parse_meta(f)
                break
        else:
            raise RuntimeError("Format error: not found meta.yaml")

        # Loop invariants: computed once instead of per member.
        yaml_names = set(yaml_files.values())
        member_names = [archive.get_name_from_info(i) for i in archive]

        for info in archive:
            fname = archive.get_name_from_info(info)
            outname = outpath / fname
            outname.parent.mkdir(parents=True, exist_ok=True)
            if fname in yaml_names:
                with archive.extractfile(info) as f:
                    config = yaml.safe_load(f)
                # Rewrite yaml: point member paths at their extracted copies
                for member in member_names:
                    config = find_path_and_change_it_recursive(
                        config, member, str(outpath / member)
                    )
                with outname.open("w", encoding="utf-8") as f:
                    yaml.safe_dump(config, f)
            else:
                archive.extract(info, path=outpath)

    # NOTE(review): if meta.yaml records an absolute path, ``outpath / value``
    # evaluates to that absolute path rather than the extracted copy. pack can
    # record one for files outside the CWD. Confirm whether such archives
    # need handling here.
    return _map_to_outpath(yaml_files, files, outpath)
def _to_relative_or_resolve(f):
# Resolve to avoid symbolic link
p = Path(f).resolve()
try:
# Change to relative if it can
p = p.relative_to(Path(".").resolve())
except ValueError:
pass
return str(p)
def pack(
    files: Dict[str, Union[str, Path]],
    yaml_files: Dict[str, Union[str, Path]],
    outpath: Union[str, Path],
    option: Iterable[Union[str, Path]] = (),
):
    """Bundle files into an archive together with a ``meta.yaml`` index.

    Args:
        files: Mapping of key -> file or directory path to include.
        yaml_files: Mapping of key -> YAML file path. These are recorded
            separately in meta.yaml.
        outpath: Archive path; its extension selects the format.
        option: Extra files or directories to include without a key. Any
            iterable is accepted, including one-shot iterators.

    Raises:
        FileNotFoundError: If any given path does not exist.
    """
    # Materialize once: option is iterated twice below, and a generator would
    # otherwise be exhausted by the existence check and silently dropped.
    option = list(option)
    for v in list(files.values()) + list(yaml_files.values()) + option:
        if not Path(v).exists():
            raise FileNotFoundError(f"No such file or directory: {v}")

    # Normalize paths (symlinks resolved, CWD-relative when possible). The
    # same strings are written to meta.yaml and passed to Archiver.add.
    files = {k: _to_relative_or_resolve(v) for k, v in files.items()}
    yaml_files = {k: _to_relative_or_resolve(v) for k, v in yaml_files.items()}
    option = [_to_relative_or_resolve(v) for v in option]

    meta_objs = dict(
        files=files,
        yaml_files=yaml_files,
        timestamp=datetime.now().timestamp(),
        python=sys.version,
    )
    try:
        import torch

        meta_objs.update(torch=str(torch.__version__))
    except ImportError:
        pass
    try:
        import espnet

        meta_objs.update(espnet=espnet.__version__)
    except ImportError:
        pass

    Path(outpath).parent.mkdir(parents=True, exist_ok=True)
    with Archiver(outpath, mode="w") as archive:
        # Write meta.yaml at the archive root
        payload = yaml.safe_dump(meta_objs).encode()
        info = archive.generate_info("meta.yaml", len(payload))
        archive.addfile(info, fileobj=BytesIO(payload))

        # Several keys may reference the same path. Add each path once, in
        # first-seen order, to avoid duplicate members (zipfile warns on them).
        for f in dict.fromkeys(
            list(yaml_files.values()) + list(files.values()) + option
        ):
            archive.add(f)
    print(f"Generate: {outpath}")