Skip to content

Commit 6bde7d0

Browse files
author
Dan
committed
Added recursive copy functionality to parallel client. Added recursive copy test for parallel client. Fixed docstring indendations
1 parent e685a3e commit 6bde7d0

File tree

4 files changed

+60
-16
lines changed

4 files changed

+60
-16
lines changed

pssh/pssh_client.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -472,34 +472,37 @@ def get_stdout(self, greenlet, return_buffers=False):
472472
'stdout' : stdout,
473473
'stderr' : stderr, }}
474474

475-
def copy_file(self, local_file, remote_file):
475+
def copy_file(self, local_file, remote_file, recurse=False):
476476
"""Copy local file to remote file in parallel
477477
478478
:param local_file: Local filepath to copy to remote host
479479
:type local_file: str
480480
:param remote_file: Remote filepath on remote host to copy file to
481481
:type remote_file: str
482-
482+
:param recurse: Whether or not to descend into directories recursively.
483+
:type recurse: bool
484+
485+
:raises: :mod:'ValueError' when a directory is supplied to local_file \
486+
and recurse is not set
487+
488+
483489
.. note ::
484490
Remote directories in `remote_file` that do not exist will be
485491
created as long as permissions allow.
486-
487-
.. note ::
488-
Path separation is handled client side so it is possible to copy
489-
to/from hosts with differing path separators, like from/to Linux
490-
and Windows.
491-
492+
492493
:rtype: List(:mod:`gevent.Greenlet`) of greenlets for remote copy \
493494
commands
494495
"""
495-
return [self.pool.spawn(self._copy_file, host, local_file, remote_file)
496+
return [self.pool.spawn(self._copy_file, host, local_file, remote_file,
497+
{'recurse' : recurse})
496498
for host in self.hosts]
497499

498-
def _copy_file(self, host, local_file, remote_file):
500+
def _copy_file(self, host, local_file, remote_file, recurse=False):
499501
"""Make sftp client, copy file"""
500502
if not self.host_clients[host]:
501-
self.host_clients[host] = SSHClient(host, user=self.user,
502-
password=self.password,
503-
port=self.port, pkey=self.pkey,
504-
forward_ssh_agent=self.forward_ssh_agent)
505-
return self.host_clients[host].copy_file(local_file, remote_file)
503+
self.host_clients[host] = SSHClient(
504+
host, user=self.user, password=self.password,
505+
port=self.port, pkey=self.pkey,
506+
forward_ssh_agent=self.forward_ssh_agent)
507+
return self.host_clients[host].copy_file(local_file, remote_file,
508+
recurse=recurse)

pssh/ssh_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def copy_file(self, local_file, remote_file, recurse=False):
308308
:type remote_file: str
309309
:param recurse: Whether or not to descend into directories recursively.
310310
:type recurse: bool
311-
311+
312312
:raises: :mod:'ValueError' when a directory is supplied to local_file \
313313
and recurse is not set
314314
"""

tests/test_pssh_client.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import paramiko
3232
import os
3333
import warnings
34+
import shutil
3435

3536
USER_KEY = paramiko.RSAKey.from_private_key_file(
3637
os.path.sep.join([os.path.dirname(__file__), 'test_client_private_key']))
@@ -342,6 +343,36 @@ def test_pssh_copy_file(self):
342343
del client
343344
server.join()
344345

346+
def test_pssh_client_directory(self):
347+
"""Tests copying directories with SSH client. Copy all the files from
348+
local directory to server, then make sure they are all present."""
349+
test_file_data = 'test'
350+
local_test_path = 'directory_test'
351+
remote_test_path = 'directory_test_copied'
352+
for path in [local_test_path, remote_test_path]:
353+
try:
354+
shutil.rmtree(path)
355+
except OSError:
356+
pass
357+
os.mkdir(local_test_path)
358+
remote_file_paths = []
359+
for i in range(0, 10):
360+
local_file_path = os.path.join(local_test_path, 'foo' + str(i))
361+
remote_file_path = os.path.join(remote_test_path, 'foo' + str(i))
362+
remote_file_paths.append(remote_file_path)
363+
test_file = open(local_file_path, 'w')
364+
test_file.write(test_file_data)
365+
test_file.close()
366+
client = ParallelSSHClient([self.host], port=self.listen_port,
367+
pkey=self.user_key)
368+
cmds = client.copy_file(local_test_path, remote_test_path, recurse=True)
369+
for cmd in cmds:
370+
cmd.get()
371+
for path in remote_file_paths:
372+
self.assertTrue(os.path.isfile(path))
373+
shutil.rmtree(local_test_path)
374+
shutil.rmtree(remote_test_path)
375+
345376
def test_pssh_pool_size(self):
346377
"""Test pool size logic"""
347378
hosts = ['host-%01d' % d for d in xrange(5)]

tests/test_ssh_client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,11 @@ def test_ssh_client_directory(self):
147147
test_file_data = 'test'
148148
local_test_path = 'directory_test'
149149
remote_test_path = 'directory_test_copied'
150+
for path in [local_test_path, remote_test_path]:
151+
try:
152+
shutil.rmtree(path)
153+
except OSError:
154+
pass
150155
os.mkdir(local_test_path)
151156
remote_file_paths = []
152157
for i in range(0, 10):
@@ -170,6 +175,11 @@ def test_ssh_client_directory_no_recurse(self):
170175
test_file_data = 'test'
171176
local_test_path = 'directory_test'
172177
remote_test_path = 'directory_test_copied'
178+
for path in [local_test_path, remote_test_path]:
179+
try:
180+
shutil.rmtree(path)
181+
except OSError:
182+
pass
173183
os.mkdir(local_test_path)
174184
remote_file_paths = []
175185
for i in range(0, 10):

0 commit comments

Comments
 (0)