Rename the connection class to serial_connection
[nitsi.git] / test.py
1 #!/usr/bin/python3
2
3 import serial
4
5 import re
6 from time import sleep
7 import sys
8
9 import libvirt
10
11 import xml.etree.ElementTree as ET
12
13 import os
14
15 import configparser
16
17 from disk import disk
18
class log():
    """Tiny stdout logger with a numeric verbosity threshold.

    ``error`` messages are always printed; ``debug`` messages are printed
    only when the level given to the constructor is 4 or higher.
    """

    def __init__(self, log_level):
        # Verbosity threshold consulted by debug().
        self.log_level = log_level

    def debug(self, string):
        """Print ``string`` as a ``DEBUG:`` line if verbosity is at least 4."""
        verbose_enough = self.log_level >= 4
        if verbose_enough:
            print("DEBUG: {}".format(string))

    def error(self, string):
        """Print ``string`` as an ``ERROR:`` line, regardless of verbosity."""
        print("ERROR: {}".format(string))
29
class libvirt_con():
    """Lazily opened connection to the libvirt hypervisor at ``uri``."""

    def __init__(self, uri):
        self.log = log(4)
        self.uri = uri
        # Underlying libvirt connection; opened on first access of ``con``.
        self.connection = None

    def get_domain_from_name(self, name):
        """Return the libvirt domain called ``name``.

        Raises:
            LookupError: if the lookup yields no domain. (The libvirt binding
                presumably raises ``libvirt.libvirtError`` itself for unknown
                names.)
        """
        dom = self.con.lookupByName(name)

        if dom is None:
            # LookupError is still a BaseException, so existing
            # ``except BaseException`` handlers keep catching it.
            raise LookupError("No such domain: {}".format(name))
        return dom

    @property
    def con(self):
        """The open libvirt connection, created on first use.

        Raises:
            libvirt.libvirtError: if the connection cannot be opened. The
                failure is logged and re-raised rather than returning ``None``
                (which only postponed the crash to the caller).
        """
        if self.connection is None:
            try:
                self.connection = libvirt.open(self.uri)
            except libvirt.libvirtError:
                self.log.error("Could not connect to: {}".format(self.uri))
                raise

            # Only reached when the connection really was opened.
            self.log.debug("Connected to: {}".format(self.uri))

        return self.connection
55
56
class machine():
    """A libvirt virtual machine used by a test.

    Holds the domain XML, the snapshot XML used to reset the VM after a run,
    the disk image that files are copied into, and the credentials used to
    log in on the serial console.
    """

    def __init__(self, vm_xml_file, snapshot_xml_file, image, root_uid, username, password):
        self.log = log(4)
        self.con = libvirt_con("qemu:///system")
        try:
            with open(vm_xml_file) as fobj:
                self.vm_xml = fobj.read()
        except FileNotFoundError as error:
            self.log.error("No such file: {}".format(vm_xml_file))

        try:
            with open(snapshot_xml_file) as fobj:
                self.snapshot_xml = fobj.read()
        except FileNotFoundError as error:
            self.log.error("No such file: {}".format(snapshot_xml_file))

        self.image = image

        if not os.path.isfile(self.image):
            self.log.error("No such file: {}".format(self.image))

        # Passed to disk.mount() in copy_in(); presumably identifies the root
        # partition inside the image -- confirm against the disk module.
        self.root_uid = root_uid
        self.disk = disk(image)

        self.username = username
        self.password = password

    def define(self):
        """Define the libvirt domain from the VM XML.

        Raises:
            RuntimeError: if libvirt returns no domain.
        """
        self.dom = self.con.con.defineXML(self.vm_xml)
        if self.dom is None:
            self.log.error("Could not define VM")
            raise RuntimeError("Could not define VM")

    def start(self):
        """Boot the defined domain.

        Raises:
            RuntimeError: if libvirt reports a negative status.
        """
        if self.dom.create() < 0:
            self.log.error("Could not start VM")
            raise RuntimeError("Could not start VM")

    def shutdown(self):
        """Ask a running domain to shut down; log an error if it is not running.

        NOTE(review): per libvirt docs, shutdown() only requests the guest to
        power off and may return before the guest has actually stopped.

        Raises:
            RuntimeError: if libvirt reports a negative status.
        """
        if not self.is_running():
            self.log.error("Domain is not running")
            return

        if self.dom.shutdown() < 0:
            self.log.error("Could not shutdown VM")
            raise RuntimeError("Could not shutdown VM")

    def undefine(self):
        """Remove the domain definition from libvirt."""
        self.dom.undefine()

    def create_snapshot(self):
        """Create a snapshot from the snapshot XML so the VM can be reset later.

        Raises:
            RuntimeError: if no snapshot was returned.
        """
        self.snapshot = self.dom.snapshotCreateXML(self.snapshot_xml)

        if not self.snapshot:
            self.log.error("Could not create snapshot")
            raise RuntimeError("Could not create snapshot")

    def revert_snapshot(self):
        """Revert the domain to the snapshot taken earlier, then delete it."""
        self.dom.revertToSnapshot(self.snapshot)
        self.snapshot.delete()

    def is_running(self):
        """Return True if libvirt reports the domain state as running."""
        state, _reason = self.dom.state()
        return state == libvirt.VIR_DOMAIN_RUNNING

    def get_serial_device(self):
        """Return the host path of the domain's serial console.

        The path is read from ``devices/serial/source`` in the live domain XML.

        Raises:
            RuntimeError: if the domain is not running or has no serial source.
        """
        if not self.is_running():
            raise RuntimeError("Domain is not running")

        xml_root = ET.fromstring(self.dom.XMLDesc(0))

        elem = xml_root.find("./devices/serial/source")
        if elem is None:
            raise RuntimeError("Domain has no serial console source")
        return elem.get("path")

    def check_is_booted_up(self):
        """Block until the domain answers a newline on its serial console."""
        serial_con = serial_connection(self.get_serial_device())

        serial_con.write("\n")
        # This will block till the domain is booted up
        serial_con.read(1)

        # NOTE(review): the serial port opened above is never closed explicitly.

    def login(self):
        """Open the serial console and log in as ``self.username``.

        Failures are logged (with their cause) rather than raised.
        """
        try:
            self.serial_con = serial_connection(self.get_serial_device(), username=self.username)
            self.serial_con.login(self.password)
        except Exception as e:
            self.log.error("Could not connect to the domain via serial console")
            self.log.error(e)

    def cmd(self, cmd):
        """Run ``cmd`` on the logged-in console and return its exit code string."""
        return self.serial_con.command(cmd)

    def copy_in(self, fr, to):
        """Copy ``fr`` into the image at ``to`` through the disk helper.

        The image is mounted at "/" for the copy. It is always unmounted and
        closed afterwards, and copy errors are logged rather than raised.
        """
        try:
            self.disk.mount(self.root_uid, "/")
            self.disk.copy_in(fr, to)
        except Exception as e:
            self.log.error(e)
        finally:
            self.disk.umount("/")
            self.disk.close()
165
class serial_connection():
    """Line-oriented access to a VM's serial console.

    Bytes read ahead of the caller (by ``peek``/``readline``) are kept in
    ``self.buffer`` so that later reads see them first.
    """

    def __init__(self, device, username=None):
        # Bytes already pulled from the device but not yet consumed.
        self.buffer = b""
        # Compiled shell-prompt regex, built lazily by back_at_prompt().
        self.back_at_prompt_pattern = None
        self.username = username
        self.log = log(1)
        self.con = serial.Serial(device)

    @staticmethod
    def _decode(data):
        """Decode console bytes as UTF-8, replacing invalid sequences.

        Console output may contain stray bytes, or a multi-byte character
        split across two reads. Those are replaced instead of raising
        UnicodeDecodeError in the middle of a login or command.
        """
        return data.decode(errors="replace")

    def close(self):
        """Close the underlying serial device."""
        self.con.close()

    def read(self, size=1):
        """Read ``size`` bytes, consuming buffered bytes before the device."""
        if len(self.buffer) >= size:
            data = self.buffer[:size]
            # Keep the bytes that were not consumed.
            self.buffer = self.buffer[size:]
            return data

        data = self.buffer
        self.buffer = b""
        # Only the bytes missing from the buffer come from the device.
        return data + self.con.read(size - len(data))

    def peek(self, size=1):
        """Return the first ``size`` bytes without consuming them.

        Bytes are read from the device only when the buffer holds fewer than
        ``size`` bytes.
        """
        missing = size - len(self.buffer)
        if missing > 0:
            self.buffer += self.con.read(missing)

        return self.buffer[:size]

    def readline(self):
        """Return the next line, including its trailing newline.

        Everything the device already has pending is moved into the buffer.
        If that yields no complete line, block on the device for the rest.
        """
        self.log.debug(self.buffer)
        self.buffer = self.buffer + self.con.read(self.con.in_waiting)
        if b"\n" in self.buffer:
            size = self.buffer.index(b"\n") + 1
            self.log.debug("We have a whole line in the buffer")
            self.log.debug(self.buffer)
            self.log.debug("We split at {}".format(size))
            data = self.buffer[:size]
            self.buffer = self.buffer[size:]
            self.log.debug(data)
            self.log.debug(self.buffer)
            return data

        data = self.buffer
        self.buffer = b""
        return data + self.con.readline()

    def back_at_prompt(self):
        """Return True if the buffered output shows the ``[<user>@...]#`` prompt."""
        data = self.peek()
        if data != b"[":
            return False

        # We need to use self.in_waiting because with self.con.in_waiting we get
        # not the complete string
        size = len(self.buffer) + self.in_waiting
        data = self.peek(size)

        if self.back_at_prompt_pattern is None:
            # Escape the username so characters such as '.' match literally.
            # str() keeps the original behaviour for username=None ("None").
            self.back_at_prompt_pattern = re.compile(
                r"^\[{}@.+\]#".format(re.escape(str(self.username))), re.MULTILINE)

        return bool(self.back_at_prompt_pattern.search(self._decode(data)))

    def log_console_line(self, line):
        """Echo one console line to stdout, unchanged."""
        self.log.debug("Get in function log_console_line()")
        sys.stdout.write(line)

    @property
    def in_waiting(self):
        """Number of pending device bytes, sampled once the input stops growing.

        Polls every 0.5 s until two consecutive readings match, so callers see
        a complete burst of output rather than a partial one.
        """
        in_waiting_before = 0
        sleep(0.5)

        while in_waiting_before != self.con.in_waiting:
            in_waiting_before = self.con.in_waiting
            sleep(0.5)

        return self.con.in_waiting

    def line_in_buffer(self):
        """Return True if the buffer holds at least one complete line."""
        return b"\n" in self.buffer

    def print_lines_in_buffer(self):
        """Echo every complete line the console has sent so far."""
        while True:
            self.log.debug("Fill buffer ...")
            self.peek(len(self.buffer) + self.in_waiting)
            self.log.debug("Current buffer length: {}".format(len(self.buffer)))
            if not self.line_in_buffer():
                self.log.debug("We have printed all lines in the buffer")
                break
            while self.line_in_buffer():
                data = self.readline()
                self.log_console_line(self._decode(data))

    def login(self, password):
        """Log in on the console as ``self.username`` with ``password``.

        Returns:
            True once the shell prompt is seen (also when already logged in),
            False if no username was configured.

        NOTE: a failed login is not detected. The final loop waits only for
        the shell prompt.
        """
        if self.username is None:
            self.log.error("Username cannot be blank")
            return False

        self.print_lines_in_buffer()

        # Hit enter to see what we get
        self.con.write(b'\n')
        # We get two new lines \r\n ?
        data = self.readline()
        self.log_console_line(self._decode(data))

        self.print_lines_in_buffer()

        if self.back_at_prompt():
            self.log.debug("We are already logged in.")
            return True

        # Consume lines until the first buffered line holds the login prompt.
        login_prompt = re.compile(r"^.*login: ")
        while True:
            # We need to use self.in_waiting because with self.con.in_waiting we get
            # not the complete string
            size = len(self.buffer) + self.in_waiting
            data = self.peek(size)

            if login_prompt.search(self._decode(data)):
                break

            self.log.debug("The pattern does not match")
            self.log_console_line(self._decode(self.readline()))

        # We can login
        self.con.write("{}\n".format(self.username).encode())
        self.con.flush()
        # read the login out of the buffer
        data = self.readline()
        self.log.debug("This is the login:{}".format(data))
        self.log_console_line(self._decode(data))

        # Evaluated only for its side effect: self.in_waiting waits until the
        # console goes quiet, giving "Password:" time to arrive in full.
        _ = self.in_waiting

        self.con.write("{}\n".format(password).encode())
        self.con.flush()
        # Print the 'Password:' line
        data = self.readline()
        self.log_console_line(self._decode(data))

        while not self.back_at_prompt():
            # This will fail if the login failed so we need to look for the failed keyword
            data = self.readline()
            self.log_console_line(self._decode(data))

        return True

    def write(self, string):
        """Send ``string`` (str) to the console and flush it."""
        self.log.debug(string)
        self.con.write(string.encode())
        self.con.flush()

    def command(self, command):
        """Run ``command`` in the logged-in shell and return its exit code.

        The command is suffixed with ``echo "END: $?"`` to print its exit
        status. Output is echoed until the shell prompt reappears.

        Returns:
            str: the exit code as text, e.g. ``"0"``.
        """
        self.write("{}; echo \"END: $?\"\n".format(command))

        # We need to read out the prompt for this command first
        # If we do not do this we will break the loop immediately
        # because the prompt for this command is still in the buffer
        data = self.readline()
        self.log_console_line(self._decode(data))

        while not self.back_at_prompt():
            data = self.readline()
            self.log_console_line(self._decode(data))

        # The last line read carries the "END: <code>" marker. Search for it
        # instead of stripping a prefix, so that output without a trailing
        # newline (e.g. "fooEND: 0") still yields "0".
        line = self._decode(data)
        self.log.debug(line)
        codes = re.findall(r"END: (\d+)", line)
        if codes:
            exit_code = codes[-1]
        else:
            exit_code = line.replace("END: ", "").strip()
        self.log.debug(exit_code)
        return exit_code
353
354
class network():
    """Define, start and remove a libvirt virtual network from an XML file."""

    def __init__(self, network_xml_file):
        self.log = log(4)
        self.con = libvirt_con("qemu:///system")
        try:
            with open(network_xml_file) as fobj:
                self.network_xml = fobj.read()
        except FileNotFoundError as error:
            # Report the file that was actually requested.
            self.log.error("No such file: {}".format(network_xml_file))

    def define(self):
        """Define the network persistently in libvirt from the XML."""
        self.network = self.con.con.networkDefineXML(self.network_xml)

        if self.network is None:
            self.log.error("Failed to define virtual network")

    def start(self):
        """Start the defined network."""
        self.network.create()

    def undefine(self):
        """Stop the network if it is active, then remove its definition.

        destroy() only stops a running network. undefine() removes the
        persistent definition created by networkDefineXML(), so no stale
        network is left behind after a test run.
        """
        if self.network.isActive():
            self.network.destroy()
        self.network.undefine()
377
378
379
class RecipeExeption(Exception):
    """Raised when a recipe cannot be parsed (e.g. a malformed line or an include loop).

    The misspelled name is kept for compatibility with existing references.
    """
382
383
384
class recipe():
    """Parse a recipe file into ``(machine, extra, command)`` tuples.

    Each non-blank line has the form ``<target> [<extra>]: <command>``:

    * ``<target>`` is one of:
      - a machine name;
      - a comma-separated list such as ``alice,bob``;
      - ``all``, which the top-level recipe expands to every machine named
        anywhere in the recipe;
      - ``include``, where ``<command>`` is a directory, relative to this
        recipe, whose ``recipe`` file is inlined at this point.
    * ``<extra>`` is an optional flag token, interpreted by the caller.
    * ``<command>`` is everything after the first ``:``.
    """

    def __init__(self, path, circle=None):
        self.log = log(4)
        self.recipe_file = path
        self.path = os.path.dirname(self.recipe_file)
        self.log.debug("Path of recipe is: {}".format(self.recipe_file))
        self._recipe = None
        self._machines = None

        # ``circle`` collects the include paths seen so far. It is shared with
        # nested recipes to detect include loops, so every top-level recipe
        # needs a fresh list: a mutable default would leak between instances.
        if circle is None:
            circle = []

        # Only the top-level recipe (empty circle) expands "all" lines.
        self.in_recursion = len(circle) != 0

        self.circle = circle
        self.log.debug(circle)
        self.log.debug(self.circle)

        if not os.path.isfile(self.recipe_file):
            self.log.error("No such file: {}".format(self.recipe_file))
            raise RecipeExeption("No such file: {}".format(self.recipe_file))

        with open(self.recipe_file) as fobj:
            self.raw_recipe = fobj.readlines()

    @property
    def recipe(self):
        """The parsed recipe, as a list of (machine, extra, command) tuples."""
        if self._recipe is None:
            self.parse()

        return self._recipe

    @property
    def machines(self):
        """Names of all machines used in the recipe, in order of first use."""
        if self._machines is None:
            self._machines = self._machine_names(self.recipe)

        return self._machines

    @staticmethod
    def _machine_names(entries):
        """Return the distinct machine names in ``entries``, skipping ``all``."""
        names = []
        for entry in entries:
            if entry[0] != "all" and entry[0] not in names:
                names.append(entry[0])
        return names

    def parse(self):
        """Parse ``self.raw_recipe`` into ``self._recipe``.

        Raises:
            RecipeExeption: on a malformed line or an include loop.
        """
        parsed = []
        for lineno, line in enumerate(self.raw_recipe, start=1):
            # Blank lines carry no step.
            if not line.strip():
                continue

            # Split on the first ':' only, so commands may contain ':' too.
            head, sep, cmd = line.partition(":")
            if not sep:
                self.log.error("Error parsing the recipe in line {}".format(lineno))
                raise RecipeExeption
            cmd = cmd.strip()

            tokens = head.split()
            if not tokens:
                self.log.error("Failed to parse the recipe in line {}".format(lineno))
                raise RecipeExeption

            target = tokens[0]
            extra = tokens[1] if len(tokens) == 2 else ""

            # We could get a machine here or a include statement
            if target == "include":
                path = os.path.normpath(self.path + "/" + cmd) + "/recipe"
                if path in self.circle:
                    self.log.error("Detect import loop!")
                    raise RecipeExeption
                self.circle.append(path)
                parsed.extend(recipe(path, circle=self.circle).recipe)
            else:
                # Support also something like 'alice,bob: echo'
                for name in target.split(","):
                    parsed.append((name.strip(), extra, cmd))

        if not self.in_recursion:
            # Expand "all" only after the whole recipe (including includes)
            # is known, so it covers machines named on later lines too.
            names = self._machine_names(parsed)
            expanded = []
            for entry in parsed:
                if entry[0] == "all":
                    expanded.extend((name, entry[1], entry[2]) for name in names)
                else:
                    expanded.append(entry)
            parsed = expanded

        self._recipe = parsed
485
486
487
class test():
    """A single test: a directory containing a ``settings`` and a ``recipe`` file."""

    def __init__(self, path):
        self.log = log(4)
        try:
            self.path = os.path.abspath(path)
        except BaseException as e:
            self.log.error("Could not get absolute path")

        self.log.debug(self.path)

        self.settings_file = "{}/settings".format(self.path)
        if not os.path.isfile(self.settings_file):
            self.log.error("No such file: {}".format(self.settings_file))

        self.recipe_file = "{}/recipe".format(self.path)
        if not os.path.isfile(self.recipe_file):
            self.log.error("No such file: {}".format(self.recipe_file))

    def read_settings(self):
        """Read the test metadata, copy settings and virtual-environment location.

        ``Copy_from`` is a comma-separated list of paths relative to the test
        directory; each entry is resolved and normalised.
        """
        self.config = configparser.ConfigParser()
        self.config.read(self.settings_file)
        self.name = self.config["DEFAULT"]["Name"]
        self.description = self.config["DEFAULT"]["Description"]
        self.copy_to = self.config["DEFAULT"]["Copy_to"]
        self.copy_from = [
            os.path.normpath(self.path + "/" + entry.strip())
            for entry in self.config["DEFAULT"]["Copy_from"].split(",")
        ]

        self.virtual_environ_name = self.config["VIRTUAL_ENVIRONMENT"]["Name"]
        self.virtual_environ_path = self.config["VIRTUAL_ENVIRONMENT"]["Path"]
        self.virtual_environ_path = os.path.normpath(self.path + "/" + self.virtual_environ_path)

    def virtual_environ_setup(self):
        """Build the network and machine objects for the virtual environment."""
        self.virtual_environ = virtual_environ(self.virtual_environ_path)

        self.virtual_networks = self.virtual_environ.get_networks()

        self.virtual_machines = self.virtual_environ.get_machines()

    def virtual_environ_start(self):
        """Bring up networks and machines, then log in on every machine.

        For each machine: define it, snapshot it, copy the test files into
        its image and boot it.
        """
        for name in self.virtual_environ.network_names:
            self.virtual_networks[name].define()
            self.virtual_networks[name].start()

        for name in self.virtual_environ.machine_names:
            self.virtual_machines[name].define()
            self.virtual_machines[name].create_snapshot()
            self.virtual_machines[name].copy_in(self.copy_from, self.copy_to)
            self.virtual_machines[name].start()

        self.log.debug("Try to login on all machines")
        for name in self.virtual_environ.machine_names:
            self.virtual_machines[name].login()

    def load_recipe(self):
        """Load and parse the recipe, exiting with status 1 on failure.

        Parsing is forced here, before any VM is started, so recipe syntax
        errors are reported without booting the virtual environment.
        """
        try:
            self.recipe = recipe(self.recipe_file)
            # The recipe parses lazily; touch it now to surface errors early.
            _ = self.recipe.recipe
        except Exception:
            self.log.error("Failed to load recipe")
            sys.exit(1)

    def run_recipe(self):
        """Run every recipe step on its machine.

        A step without a flag fails on a non-zero exit code. A step flagged
        with ``!`` fails on a zero exit code.

        Returns:
            bool: False at the first failing step, True if every step passed.
        """
        for machine_name, extra, cmd in self.recipe.recipe:
            return_value = self.virtual_machines[machine_name].cmd(cmd)
            self.log.debug("Return value is: {}".format(return_value))
            if return_value != "0" and extra == "":
                self.log.error("Failed to execute command '{}' on {}, return code: {}".format(cmd, machine_name, return_value))
                return False
            elif return_value == "0" and extra == "!":
                self.log.error("Succeded to execute command '{}' on {}, return code: {}".format(cmd, machine_name, return_value))
                return False
            else:
                self.log.debug("Command '{}' on {} returned with: {}".format(cmd, machine_name, return_value))

        return True

    def virtual_environ_stop(self):
        """Shut down, reset and undefine all machines, then remove the networks."""
        for name in self.virtual_environ.machine_names:
            self.virtual_machines[name].shutdown()
            self.virtual_machines[name].revert_snapshot()
            self.virtual_machines[name].undefine()

        for name in self.virtual_environ.network_names:
            self.virtual_networks[name].undefine()
577
578
class virtual_environ():
    """Machines and networks of one virtual environment.

    ``path`` is a directory with a ``settings`` file. Its DEFAULT section
    names the environment and lists comma-separated ``machines`` and
    ``networks``. Each listed name has its own section, whose XML file paths
    are resolved relative to ``path``.
    """

    def __init__(self, path):
        self.log = log(4)
        try:
            self.path = os.path.abspath(path)
        except BaseException as e:
            self.log.error("Could not get absolute path")

        self.log.debug(self.path)

        self.settings_file = "{}/settings".format(self.path)
        if not os.path.isfile(self.settings_file):
            self.log.error("No such file: {}".format(self.settings_file))

        self.log.debug(self.settings_file)
        self.config = configparser.ConfigParser()
        self.config.read(self.settings_file)

        defaults = self.config["DEFAULT"]
        self.name = defaults["name"]
        self.machines_string = defaults["machines"]
        self.networks_string = defaults["networks"]

        self.machines = [entry.strip() for entry in self.machines_string.split(",")]
        self.networks = [entry.strip() for entry in self.networks_string.split(",")]

        self.log.debug(self.machines)
        self.log.debug(self.networks)

    def _resolve(self, relative):
        """Return ``relative`` appended to this environment's directory, normalised."""
        return os.path.normpath(self.path + "/" + relative)

    def get_networks(self):
        """Return a dict mapping each network name to its ``network`` object."""
        networks = {}
        for name in self.networks:
            self.log.debug(name)
            section = self.config[name]
            networks.setdefault(name, network(self._resolve(section["xml_file"])))
        return networks

    def get_machines(self):
        """Return a dict mapping each machine name to its ``machine`` object."""
        machines = {}
        for name in self.machines:
            self.log.debug(name)
            section = self.config[name]
            vm = machine(
                self._resolve(section["xml_file"]),
                self._resolve(section["snapshot_xml_file"]),
                section["image"],
                section["root_uid"],
                section["username"],
                section["password"])
            machines.setdefault(name, vm)

        return machines

    @property
    def machine_names(self):
        """Machine names in the order listed in the settings file."""
        return self.machines

    @property
    def network_names(self):
        """Network names in the order listed in the settings file."""
        return self.networks
641
642
if __name__ == "__main__":
    # Run one test directory end to end. The process exits 0 only if every
    # recipe step passed, so callers (CI, shell scripts) can rely on the status.
    import argparse
    import traceback

    parser = argparse.ArgumentParser()

    parser.add_argument("-d", "--directory", dest="dir", required=True,
                        help="directory containing the test's settings and recipe")

    args = parser.parse_args()

    currenttest = test(args.dir)
    currenttest.read_settings()
    currenttest.virtual_environ_setup()
    currenttest.load_recipe()

    succeeded = False
    try:
        currenttest.virtual_environ_start()
        # run_recipe() returns False at the first failing step.
        succeeded = currenttest.run_recipe() is not False
    except BaseException:
        # Print the full traceback, not just the (often empty) message.
        traceback.print_exc()
    finally:
        currenttest.virtual_environ_stop()

    sys.exit(0 if succeeded else 1)
663