@@ -54,33 +54,28 @@ def __init__(
5454 ):
5555 self .host = host
5656 self .user = user
57- self .ssh_password = ssh_password if key_path is None else None
57+ self .ssh_password = ssh_password if not key_path else None
5858 self .port = port
5959 self .nb_bytes = 1024
60- self .keys , self .transport = [], None
60+ self .keys = []
61+ self .transport = None
6162 key_type = key_type .lower ()
6263
6364 if key_path :
64- self .keys .append (
65- _KEY_TYPES [key_type ].from_private_key (
66- open (path .expanduser (key_path ), 'r' ),
67- key_password ,
68- )
69- )
65+ key_file = open (path .expanduser (key_path ), 'r' )
66+ key = _KEY_TYPES [key_type ].from_private_key (key_file , key_password )
67+ self .keys .append (key )
7068 elif ssh_password is None :
7169 self .keys = paramiko .Agent ().get_keys ()
72-
7370 try :
74- key_file = _KEY_TYPES [key_type ].from_private_key (
75- open (path .expanduser (f"~/.ssh/id_{ key_type } " ), 'r' ),
76- key_password
77- )
71+ key_file = open (path .expanduser (f"~/.ssh/id_{ key_type } " ), 'r' )
72+ key = _KEY_TYPES [key_type ].from_private_key (
73+ key_file , key_password )
7874 except Exception :
7975 pass
8076 else :
81- self .keys .insert (
82- len (self .keys ) if key_password is None else 0 , key_file
83- )
77+ index = len (self .keys ) if key_password is None else 0
78+ self .keys .insert (index , key )
8479
8580 if not self .keys :
8681 logging .error ("No valid key found" )
@@ -96,10 +91,8 @@ def connect(self):
9691
9792 if self .ssh_password is not None :
9893 try :
99- self .transport .connect (
100- username = self .user ,
101- password = self .ssh_password ,
102- )
94+ self .transport .connect (username = self .user ,
95+ password = self .ssh_password )
10396 except paramiko .SSHException :
10497 pass
10598 else :
@@ -117,27 +110,24 @@ def connect(self):
117110 logging .info (f"Successfully connected to { self .user } @{ self .host } " )
118111 return 0
119112
120- def __run_until_event (
113+ def _run_until_event (
121114 self ,
122115 command ,
123116 stop_event ,
124117 display = True ,
125- capture_output = False ,
118+ capture = False ,
126119 shell = True ,
127120 combine_stderr = False ,
128121 ):
122+ exit_code , output = 0 , ""
129123 channel = self .transport .open_session ()
130- output = ""
131-
132124 channel .settimeout (2 )
133125 channel .set_combine_stderr (combine_stderr )
134-
135126 if shell :
136127 channel .get_pty ()
137-
138128 channel .exec_command (command )
139129
140- if not display and not capture_output :
130+ if not display and not capture :
141131 stop_event .wait ()
142132 else :
143133 while True :
@@ -148,102 +138,92 @@ def __run_until_event(
148138 break
149139 continue
150140
151- if not len ( raw_data ) :
141+ if not raw_data :
152142 break
153-
154143 data = raw_data .decode ("utf-8" )
155-
156144 if display :
157145 print (data , end = '' )
158-
159- if capture_output :
146+ if capture :
160147 output += data
161-
162148 if stop_event .is_set ():
163149 break
164150
165151 channel .close ()
166152
167- if not channel .exit_status_ready ():
168- return ( 0 , output . splitlines () )
153+ if channel .exit_status_ready ():
154+ exit_code = channel . recv_exit_status ( )
169155
170- return (channel . recv_exit_status () , output .splitlines ())
156+ return (exit_code , output .splitlines ())
171157
172- def __run_until_exit (
158+ def _run_until_exit (
173159 self ,
174160 command ,
175161 timeout ,
176162 display = True ,
177- capture_output = False ,
163+ capture = False ,
178164 shell = True ,
179165 combine_stderr = False ,
180166 ):
167+ exit_code , output = 0 , ""
181168 channel = self .transport .open_session ()
182- output = ""
183-
184169 channel .settimeout (timeout )
185170 channel .set_combine_stderr (combine_stderr )
186-
187171 if shell :
188172 channel .get_pty ()
189-
190173 channel .exec_command (command )
191174
192175 try :
193- if not display and not capture_output :
176+ if not display and not capture :
194177 return (channel .recv_exit_status (), output .splitlines ())
195178 else :
196179 while True :
197180 raw_data = channel .recv (self .nb_bytes )
198-
199- if not len (raw_data ):
181+ if not raw_data :
200182 break
201-
202183 data = raw_data .decode ("utf-8" )
203-
204184 if display :
205185 print (data , end = '' )
206-
207- if capture_output :
186+ if capture :
208187 output += data
209188 except socket .timeout :
210189 logging .warning (f"Timeout after { timeout } s" )
211- return ( 1 , output . splitlines ())
190+ exit_code = 1
212191 except KeyboardInterrupt :
213192 logging .info ("KeyboardInterrupt" )
214- return (0 , output .splitlines ())
193+ exit_code = 0
194+ else :
195+ exit_code = channel .recv_exit_status ()
215196 finally :
216197 channel .close ()
217-
218- return (channel .recv_exit_status (), output .splitlines ())
198+ return (exit_code , output .splitlines ())
219199
220200 def run (
221201 self ,
222202 command ,
223203 display = False ,
224- capture_output = False ,
204+ capture = False ,
225205 shell = True ,
226206 combine_stderr = False ,
227207 timeout = None ,
228208 stop_event = None ,
229209 ):
230- if stop_event :
231- return self .__run_until_event (
210+ if stop_event is not None :
211+ return self ._run_until_event (
232212 command ,
233213 stop_event ,
234214 display = display ,
235215 shell = shell ,
236216 combine_stderr = combine_stderr ,
237- capture_output = capture_output ,
217+ capture = capture ,
238218 )
239219 else :
240- return self .__run_until_exit (
220+ return self ._run_until_exit (
241221 command ,
242222 timeout ,
243223 display = display ,
244224 shell = shell ,
245225 combine_stderr = combine_stderr ,
246- capture_output = capture_output ,
226+ capture = capture ,
247227 )
248228
249229 def disconnect (self ):
@@ -254,12 +234,11 @@ def __getattr__(self, target):
254234 def wrapper (* args , ** kwargs ):
255235 if not self .transport .is_authenticated ():
256236 logging .error ("SSH session is not ready" )
257- return 1
237+ return
258238
259239 sftp_channel = SFTPController .from_transport (self .transport )
260240 r = getattr (sftp_channel , target )(* args , ** kwargs )
261241 sftp_channel .close ()
262-
263242 return r
264243
265244 return wrapper
0 commit comments