Improve archive unpacker

This commit is contained in:
Ivan Kravets
2017-09-24 00:33:12 +03:00
parent 837b040761
commit d9ae367281
2 changed files with 23 additions and 16 deletions

View File

@ -187,8 +187,8 @@ class PkgInstallerMixin(object):
@staticmethod @staticmethod
def unpack(source_path, dest_dir): def unpack(source_path, dest_dir):
fu = FileUnpacker(source_path, dest_dir) with FileUnpacker(source_path) as fu:
return fu.start() return fu.unpack(dest_dir)
@staticmethod @staticmethod
def get_install_dirname(manifest): def get_install_dirname(manifest):

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from os import chmod from os import chmod
from os.path import join, splitext from os.path import join
from tarfile import open as tarfile_open from tarfile import open as tarfile_open
from time import mktime from time import mktime
from zipfile import ZipFile from zipfile import ZipFile
@ -39,6 +39,9 @@ class ArchiveBase(object):
def after_extract(self, item, dest_dir): def after_extract(self, item, dest_dir):
pass pass
def close(self):
self._afo.close()
class TARArchive(ArchiveBase): class TARArchive(ArchiveBase):
@ -76,28 +79,32 @@ class ZIPArchive(ArchiveBase):
class FileUnpacker(object): class FileUnpacker(object):
def __init__(self, archpath, dest_dir="."): def __init__(self, archpath):
self._archpath = archpath self.archpath = archpath
self._dest_dir = dest_dir
self._unpacker = None self._unpacker = None
_, archext = splitext(archpath.lower()) def __enter__(self):
if archext in (".gz", ".bz2"): if self.archpath.lower().endswith((".gz", ".bz2")):
self._unpacker = TARArchive(archpath) self._unpacker = TARArchive(self.archpath)
elif archext == ".zip": elif self.archpath.lower().endswith(".zip"):
self._unpacker = ZIPArchive(archpath) self._unpacker = ZIPArchive(self.archpath)
if not self._unpacker: if not self._unpacker:
raise UnsupportedArchiveType(archpath) raise UnsupportedArchiveType(self.archpath)
return self
def start(self): def __exit__(self, *args):
if self._unpacker:
self._unpacker.close()
def unpack(self, dest_dir="."):
assert self._unpacker
if app.is_disabled_progressbar(): if app.is_disabled_progressbar():
click.echo("Unpacking...") click.echo("Unpacking...")
for item in self._unpacker.get_items(): for item in self._unpacker.get_items():
self._unpacker.extract_item(item, self._dest_dir) self._unpacker.extract_item(item, dest_dir)
else: else:
items = self._unpacker.get_items() items = self._unpacker.get_items()
with click.progressbar(items, label="Unpacking") as pb: with click.progressbar(items, label="Unpacking") as pb:
for item in pb: for item in pb:
self._unpacker.extract_item(item, self._dest_dir) self._unpacker.extract_item(item, dest_dir)
return True return True