Add support for the all: statement
[nitsi.git] / test.py
1 #!/usr/bin/python3
2
3 import serial
4
5 import re
6 from time import sleep
7 import sys
8
9 import libvirt
10
11 import xml.etree.ElementTree as ET
12
13 import os
14
15 import configparser
16
class log():
    """Tiny leveled logger writing to stdout.

    ``debug`` output appears only when ``log_level`` is at least 4;
    ``error`` output is always printed.
    """

    def __init__(self, log_level):
        self.log_level = log_level

    def debug(self, string):
        """Print *string* prefixed with ``DEBUG:`` if the level permits."""
        if self.log_level < 4:
            return
        print("DEBUG: {}".format(string))

    def error(self, string):
        """Print *string* prefixed with ``ERROR:`` unconditionally."""
        message = "ERROR: {}".format(string)
        print(message)
27
class libvirt_con():
    """Lazily opened connection to a libvirt hypervisor identified by *uri*."""

    def __init__(self, uri):
        self.log = log(4)
        self.uri = uri
        # Opened on first access of the ``con`` property.
        self.connection = None

    def get_domain_from_name(self, name):
        """Return the libvirt domain called *name*.

        Raises:
            LookupError: if the lookup yields None.
        """
        dom = self.con.lookupByName(name)

        if dom is None:
            raise LookupError("No such domain: {}".format(name))
        return dom

    @property
    def con(self):
        """The libvirt connection, opened on first use.

        Raises whatever ``libvirt.open`` raises after logging the failure;
        the connection is retried on the next access.
        """
        if self.connection is None:
            try:
                self.connection = libvirt.open(self.uri)
            except Exception:
                self.log.error("Could not connect to: {}".format(self.uri))
                # Re-raise instead of handing None to callers, which would
                # only surface later as an unrelated AttributeError.
                raise

            self.log.debug("Connected to: {}".format(self.uri))

        return self.connection
53
54
class vm():
    """A libvirt domain under test, defined from XML and driven via its serial console."""

    def __init__(self, vm_xml_file, snapshot_xml_file, image, root_uid, username, password):
        self.log = log(4)
        self.con = libvirt_con("qemu:///system")
        try:
            with open(vm_xml_file) as fobj:
                self.vm_xml = fobj.read()
        except FileNotFoundError:
            self.log.error("No such file: {}".format(vm_xml_file))

        try:
            with open(snapshot_xml_file) as fobj:
                self.snapshot_xml = fobj.read()
        except FileNotFoundError:
            self.log.error("No such file: {}".format(snapshot_xml_file))

        self.image = image

        if not os.path.isfile(self.image):
            self.log.error("No such file: {}".format(self.image))

        self.root_uid = root_uid

        # Credentials used by login() on the serial console.
        self.username = username
        self.password = password

    def define(self):
        """Define the domain from ``self.vm_xml``; raise if libvirt returns nothing."""
        self.dom = self.con.con.defineXML(self.vm_xml)
        if self.dom is None:
            self.log.error("Could not define VM")
            raise RuntimeError("Could not define VM")

    def start(self):
        """Start the defined domain."""
        if self.dom.create() < 0:
            self.log.error("Could not start VM")
            raise RuntimeError("Could not start VM")

    def shutdown(self):
        """Request shutdown of a running domain; only logs if it is not running."""
        if not self.is_running():
            self.log.error("Domain is not running")
            return

        if self.dom.shutdown() < 0:
            self.log.error("Could not shutdown VM")
            raise RuntimeError("Could not shutdown VM")

    def undefine(self):
        """Remove the domain definition from libvirt."""
        self.dom.undefine()

    def create_snapshot(self):
        """Take a snapshot from ``self.snapshot_xml`` and keep it in ``self.snapshot``."""
        self.snapshot = self.dom.snapshotCreateXML(self.snapshot_xml)

        if not self.snapshot:
            self.log.error("Could not create snapshot")
            raise RuntimeError("Could not create snapshot")

    def revert_snapshot(self):
        """Revert the domain to ``self.snapshot`` and then delete that snapshot."""
        self.dom.revertToSnapshot(self.snapshot)
        self.snapshot.delete()

    def is_running(self):
        """Return True if libvirt reports the domain state as running."""
        state, _reason = self.dom.state()
        return state == libvirt.VIR_DOMAIN_RUNNING

    def get_serial_device(self):
        """Return the host path of the domain's serial console device.

        Raises:
            RuntimeError: if the domain is not running.
            LookupError: if the domain XML has no ``devices/serial/source``.
        """
        if not self.is_running():
            raise RuntimeError("Domain is not running")

        xml_root = ET.fromstring(self.dom.XMLDesc(0))

        elem = xml_root.find("./devices/serial/source")
        if elem is None:
            raise LookupError("Domain has no serial console source")
        return elem.get("path")

    def check_is_booted_up(self):
        """Block until the domain answers on its serial console."""
        serial_con = connection(self.get_serial_device())
        try:
            serial_con.write("\n")
            # This will block till the domain is booted up
            serial_con.read(1)
        finally:
            # connection has no close(); release the underlying serial port.
            serial_con.con.close()

    def login(self):
        """Open the serial console and log in; failures are logged, not raised."""
        try:
            self.serial_con = connection(self.get_serial_device(), username=self.username)
            self.serial_con.login(self.password)
        except Exception as e:
            # Exception (not BaseException) so Ctrl-C is not swallowed here.
            self.log.error("Could not connect to the domain via serial console: {}".format(e))

    def cmd(self, cmd):
        """Run *cmd* on the logged-in console and return its exit code string."""
        return self.serial_con.command(cmd)
152
153
class connection():
    """Line-oriented access to a VM's serial console.

    Wraps a pyserial ``Serial`` port. Bytes pulled from the port ahead of
    time (by ``peek``) are kept in ``self.buffer`` and handed out first by
    ``read``/``readline``. This lets the code look ahead for a shell prompt
    without losing console output.
    """

    def __init__(self, device, username=None):
        self.buffer = b""
        # Compiled on first use by back_at_prompt().
        self.back_at_prompt_pattern = None
        self.username = username
        self.log = log(1)
        self.con = serial.Serial(device)

    def read(self, size=1):
        """Consume and return *size* bytes, serving buffered bytes first."""
        if len(self.buffer) >= size:
            data = self.buffer[:size]
            # Keep the bytes that were not consumed.
            self.buffer = self.buffer[size:]
            return data

        data = self.buffer
        # Only the part that is not buffered has to come from the port.
        size = size - len(self.buffer)
        self.buffer = b""
        return data + self.con.read(size)

    def peek(self, size=1):
        """Return the first *size* bytes without consuming them.

        Reads from the port only the bytes missing from the buffer.
        """
        if len(self.buffer) < size:
            self.buffer += self.con.read(size=size - len(self.buffer))

        return self.buffer[:size]

    def readline(self):
        """Consume and return one line, including its trailing newline.

        Whatever the port has pending is moved into the buffer first. If the
        buffer still holds no newline, the rest comes from ``Serial.readline``.
        """
        self.log.debug(self.buffer)
        self.buffer = self.buffer + self.con.read(self.con.in_waiting)
        if b"\n" in self.buffer:
            size = self.buffer.index(b"\n") + 1
            self.log.debug("We have a whole line in the buffer")
            self.log.debug(self.buffer)
            self.log.debug("We split at {}".format(size))
            data = self.buffer[:size]
            self.buffer = self.buffer[size:]
            self.log.debug(data)
            self.log.debug(self.buffer)
            return data

        data = self.buffer
        self.buffer = b""
        return data + self.con.readline()

    def back_at_prompt(self):
        """Return True if the buffered output contains a ``[<username>@...]#`` prompt line."""
        data = self.peek()
        if data != b"[":
            return False

        # Use self.in_waiting (waits until the port stops receiving) rather
        # than self.con.in_waiting, which may not cover the whole prompt yet.
        size = len(self.buffer) + self.in_waiting
        data = self.peek(size)

        if self.back_at_prompt_pattern is None:
            self.back_at_prompt_pattern = re.compile(
                r"^\[{}@.+\]#".format(self.username), re.MULTILINE)

        return bool(self.back_at_prompt_pattern.search(data.decode()))

    def log_console_line(self, line):
        """Echo a console line to stdout unchanged."""
        self.log.debug("Get in function log_console_line()")
        sys.stdout.write(line)

    @property
    def in_waiting(self):
        """Bytes pending on the port once the count has stopped changing.

        Polls every 0.5 s until two consecutive readings agree, so a message
        that is still being transmitted is not cut off.
        """
        in_waiting_before = 0
        sleep(0.5)

        while in_waiting_before != self.con.in_waiting:
            in_waiting_before = self.con.in_waiting
            sleep(0.5)

        return self.con.in_waiting

    def line_in_buffer(self):
        """Return True if the buffer holds at least one complete line."""
        return b"\n" in self.buffer

    def readline2(self, pattern=None):
        """Read the port byte by byte, echoing each character to stdout.

        Returns ``{"string": <text so far>, "return-code": 1}`` as soon as the
        text read so far matches *pattern*. Returns ``{"return-code": 0}`` when
        a newline arrives first. Bypasses ``self.buffer``.
        """
        string = ""
        if pattern:
            pattern = re.compile(pattern)

        # NOTE(review): if the port has a read timeout, an empty read keeps
        # this loop spinning; the default Serial() blocks instead.
        while True:
            char = self.con.read(1)
            string = string + char.decode("utf-8")
            print(char.decode("utf-8"), end="")

            if pattern and pattern.match(string):
                return {"string": string, "return-code": 1}

            if char == b"\n":
                return {"return-code": 0}

    def check_logged_in(self, username):
        """Return True if the next console output shows a ``[<username>@...]#`` prompt."""
        # readline() accepts no pattern; readline2() implements the matching.
        pattern = r"^\[" + re.escape(username) + r"@.+\]#"
        data = self.readline2(pattern=pattern)
        if data["return-code"] == 1:
            print("We are logged in")
            return True

        print("We are  not logged in")
        return False

    def print_lines_in_buffer(self):
        """Echo every complete line the console has sent so far."""
        while True:
            self.log.debug("Fill buffer ...")
            self.peek(len(self.buffer) + self.in_waiting)
            self.log.debug("Current buffer length: {}".format(len(self.buffer)))
            if not self.line_in_buffer():
                self.log.debug("We have printed all lines in the buffer")
                break
            while self.line_in_buffer():
                data = self.readline()
                self.log_console_line(data.decode())

    def login(self, password):
        """Log in on the console as ``self.username`` using *password*.

        Returns True once a shell prompt is seen (immediately if already
        logged in). Returns False if no username was configured.
        """
        if self.username is None:
            self.log.error("Username cannot be blank")
            return False

        self.print_lines_in_buffer()

        # Hit enter to see what we get
        self.con.write(b'\n')
        # We get two new lines \r\n ?
        data = self.readline()
        self.log_console_line(data.decode())

        self.print_lines_in_buffer()

        if self.back_at_prompt():
            self.log.debug("We are already logged in.")
            return True

        login_pattern = re.compile(r"^.*login: ")
        # Echo lines until the "login: " prompt is in the buffer.
        while True:
            # self.in_waiting waits for the complete string, unlike
            # self.con.in_waiting.
            size = len(self.buffer) + self.in_waiting
            data = self.peek(size)

            if login_pattern.search(data.decode()):
                break
            self.log.debug("The pattern does not match")
            self.log_console_line(self.readline().decode())

        # We can login. Write directly (not via self.write) so credentials
        # never reach the debug log.
        string = "{}\n".format(self.username)
        self.con.write(string.encode())
        self.con.flush()
        # Read the echoed username out of the buffer
        data = self.readline()
        self.log.debug("This is the login:{}".format(data))
        self.log_console_line(data.decode())

        # Accessing self.in_waiting blocks until the console stops sending,
        # i.e. until the full "Password:" prompt has arrived.
        _ = self.in_waiting

        string = "{}\n".format(password)
        self.con.write(string.encode())
        self.con.flush()
        # Print the 'Password:' line
        data = self.readline()
        self.log_console_line(data.decode())

        # TODO: detect a failed login; until then this loop only ends once a
        # prompt appears.
        while not self.back_at_prompt():
            data = self.readline()
            self.log_console_line(data.decode())

        return True

    def write(self, string):
        """Send *string* to the console and flush it."""
        self.log.debug(string)
        self.con.write(string.encode())
        self.con.flush()

    def command(self, command):
        """Run *command* in the logged-in shell and return its exit code as a string.

        The command is suffixed with ``echo "END: $?"``. Output is echoed
        until the prompt returns. The last line read before the prompt is
        parsed as the exit status.
        """
        self.write("{}; echo \"END: $?\"\n".format(command))

        # Read out the line with the prompt and the echoed command first;
        # otherwise back_at_prompt() would see that prompt and end the loop
        # immediately.
        data = self.readline()
        self.log_console_line(data.decode())

        while not self.back_at_prompt():
            data = self.readline()
            self.log_console_line(data.decode())

        # The last line read holds "END: <exit code>".
        self.log.debug(data.decode())
        data = data.decode().replace("END: ", "")
        self.log.debug(data)
        self.log.debug(data.strip())
        return data.strip()
376
377
# A class which defines and removes a libvirt virtual network based on an XML file
class network():
    """A libvirt virtual network described by an XML file."""

    def __init__(self, network_xml_file):
        self.log = log(4)
        self.con = libvirt_con("qemu:///system")
        try:
            with open(network_xml_file) as fobj:
                self.network_xml = fobj.read()
        except FileNotFoundError:
            # Report the file that was actually requested.
            self.log.error("No such file: {}".format(network_xml_file))

    def define(self):
        """Define the network from ``self.network_xml``.

        Raises:
            RuntimeError: if libvirt returns no network object.
        """
        self.network = self.con.con.networkDefineXML(self.network_xml)

        # Check the returned object, not the module-level class of the same name.
        if self.network is None:
            self.log.error("Failed to define virtual network")
            raise RuntimeError("Failed to define virtual network")

    def start(self):
        """Start the defined network."""
        self.network.create()

    def undefine(self):
        """Stop the network and remove its definition from libvirt."""
        # destroy() only stops the network; undefine() also drops the
        # persistent definition created by networkDefineXML().
        self.network.destroy()
        self.network.undefine()
400
401
402
class RecipeExeption(Exception):
    """Raised when a recipe line is malformed or an include loop is found.

    NOTE: the misspelled name is kept as-is because other code raises and
    catches it under this name.
    """
405
406
407
# Reads a recipe, checks that its syntax is valid and provides
# the steps as (host, extra, command) tuples
class recipe():
    """Parsed recipe file of a test.

    Each line has the form ``<machine> [<extra>]: <command>``. The test
    runner treats ``extra == "!"`` as "the command is expected to fail".
    The pseudo-machine ``include`` pulls in ``<dir>/recipe`` relative to this
    file. The pseudo-machine ``all`` runs the command on every machine named
    in the (complete) recipe.
    """

    def __init__(self, path, circle=None):
        self.log = log(4)
        self.recipe_file = path
        self.path = os.path.dirname(self.recipe_file)
        self.log.debug("Path of recipe is: {}".format(self.recipe_file))
        self._recipe = None
        self._machines = None

        # ``circle`` holds the recipe files on the include chain leading to
        # this one; it is empty only for the top-level recipe. A fresh list
        # per call avoids sharing (and mutating) a default argument.
        if circle is None:
            circle = []
        self.in_recursion = len(circle) != 0
        self.circle = circle
        self.log.debug(self.circle)

        try:
            with open(self.recipe_file) as fobj:
                self.raw_recipe = fobj.readlines()
        except FileNotFoundError:
            self.log.error("No such file: {}".format(self.recipe_file))
            raise

    @property
    def recipe(self):
        """The parsed recipe as a list of (machine, extra, command) tuples."""
        if not self._recipe:
            self.parse()

        return self._recipe

    @property
    def machines(self):
        """Machine names used in the recipe in order of first use, excluding "all"."""
        if self._recipe is None:
            self.parse()

        if not self._machines:
            self._machines = []
            for entry in self._recipe:
                if entry[0] != "all" and entry[0] not in self._machines:
                    self._machines.append(entry[0])

        return self._machines

    def parse(self):
        """Parse ``self.raw_recipe`` into ``self._recipe``.

        Raises:
            RecipeExeption: on a malformed line or an include loop.
        """
        self._recipe = []
        # Any cached machine list would be stale after re-parsing.
        self._machines = None
        for i, line in enumerate(self.raw_recipe, start=1):
            # Split at the first colon only: commands may contain colons.
            raw_line = line.split(":", 1)
            if len(raw_line) < 2:
                self.log.error("Error parsing the recipe in line {}".format(i))
                raise RecipeExeption
            cmd = raw_line[1].strip()
            # "<machine> [<extra>]"; split() also collapses repeated whitespace.
            raw_line = raw_line[0].split()
            if not raw_line:
                self.log.error("Failed to parse the recipe in line {}".format(i))
                raise RecipeExeption

            machine = raw_line[0]
            extra = raw_line[1] if len(raw_line) == 2 else ""

            # We could get a machine here or an include statement
            if machine == "include":
                path = os.path.normpath(self.path + "/" + cmd)
                path = path + "/recipe"
                if path in self.circle:
                    self.log.error("Detect import loop!")
                    raise RecipeExeption
                # Pass a copy: only recipes on the current include chain form
                # a loop, so including the same recipe twice is allowed.
                recipe_to_include = recipe(path, circle=self.circle + [path])
                self._recipe.extend(recipe_to_include.recipe)
            else:
                self._recipe.append((machine, extra, cmd))

        # Expand "all" only at the top level and only after every line
        # (including included recipes) is known, so all machines are covered.
        if not self.in_recursion:
            expanded = []
            for entry in self._recipe:
                if entry[0] != "all":
                    expanded.append(entry)
                else:
                    for name in self.machines:
                        expanded.append((name, entry[1], entry[2]))
            self._recipe = expanded
505
506
507
class test():
    """A single test: a directory containing ``settings`` and ``recipe`` files."""

    def __init__(self, path):
        self.log = log(4)
        try:
            self.path = os.path.abspath(path)
        except Exception:
            self.log.error("Could not get absolute path")
            raise

        self.log.debug(self.path)

        self.settings_file = "{}/settings".format(self.path)
        if not os.path.isfile(self.settings_file):
            self.log.error("No such file: {}".format(self.settings_file))

        self.recipe_file = "{}/recipe".format(self.path)
        if not os.path.isfile(self.recipe_file):
            self.log.error("No such file: {}".format(self.recipe_file))

    def read_settings(self):
        """Load name, description and the virtual environment location."""
        self.config = configparser.ConfigParser()
        self.config.read(self.settings_file)
        self.name = self.config["DEFAULT"]["Name"]
        self.description = self.config["DEFAULT"]["Description"]

        self.virtual_environ_name = self.config["VIRTUAL_ENVIRONMENT"]["Name"]
        self.virtual_environ_path = self.config["VIRTUAL_ENVIRONMENT"]["Path"]
        # The environment path is relative to the test directory.
        self.virtual_environ_path = os.path.normpath(self.path + "/" + self.virtual_environ_path)

    def virtual_environ_setup(self):
        """Create the network and machine objects of the virtual environment."""
        self.virtual_environ = virtual_environ(self.virtual_environ_path)

        self.virtual_networks = self.virtual_environ.get_networks()

        self.virtual_machines = self.virtual_environ.get_machines()

    def virtual_environ_start(self):
        """Bring up networks, then machines (snapshotted first), then log in everywhere."""
        for name in self.virtual_environ.network_names:
            self.virtual_networks[name].define()
            self.virtual_networks[name].start()

        for name in self.virtual_environ.machine_names:
            self.virtual_machines[name].define()
            self.virtual_machines[name].create_snapshot()
            self.virtual_machines[name].start()

        self.log.debug("Try to login on all machines")
        for name in self.virtual_environ.machine_names:
            self.virtual_machines[name].login()

    def load_recipe(self):
        """Parse the recipe file; exit the process with status 1 on failure."""
        try:
            self.recipe = recipe(self.recipe_file)
        except Exception:
            self.log.error("Failed to load recipe")
            sys.exit(1)

    def run_recipe(self):
        """Run every recipe step on its machine.

        Returns False at the first step whose exit code contradicts its
        expectation. A plain step must return "0"; a "!" step must not.
        Returns True when all steps pass.
        """
        for machine, extra, cmd in self.recipe.recipe:
            return_value = self.virtual_machines[machine].cmd(cmd)
            self.log.debug("Return value is: {}".format(return_value))
            if return_value != "0" and extra == "":
                self.log.error("Failed to execute command '{}' on {}, return code: {}".format(cmd, machine, return_value))
                return False
            elif return_value == "0" and extra == "!":
                self.log.error("Succeded to execute command '{}' on {}, return code: {}".format(cmd, machine, return_value))
                return False
            else:
                self.log.debug("Command '{}' on {} returned with: {}".format(cmd, machine, return_value))

        return True

    def virtual_environ_stop(self):
        """Shut down, restore and undefine every machine, then remove the networks."""
        for name in self.virtual_environ.machine_names:
            self.virtual_machines[name].shutdown()
            self.virtual_machines[name].revert_snapshot()
            self.virtual_machines[name].undefine()

        for name in self.virtual_environ.network_names:
            self.virtual_networks[name].undefine()
585
586
# Provides all VMs and networks of a virtual environment
# together with the paths to their XML files
class virtual_environ():
    """A virtual environment directory described by its ``settings`` file.

    ``[DEFAULT]`` lists comma-separated ``machines`` and ``networks``. Each
    of those has its own section naming its XML file(s).
    """

    def __init__(self, path):
        self.log = log(4)
        try:
            self.path = os.path.abspath(path)
        except Exception:
            self.log.error("Could not get absolute path")
            raise

        self.log.debug(self.path)

        self.settings_file = "{}/settings".format(self.path)
        if not os.path.isfile(self.settings_file):
            self.log.error("No such file: {}".format(self.settings_file))

        self.log.debug(self.settings_file)
        self.config = configparser.ConfigParser()
        self.config.read(self.settings_file)
        self.name = self.config["DEFAULT"]["name"]
        self.machines_string = self.config["DEFAULT"]["machines"]
        self.networks_string = self.config["DEFAULT"]["networks"]

        # Empty entries ("networks =" or a trailing comma) are dropped so they
        # are not looked up as config sections later.
        self.machines = self._split_names(self.machines_string)
        self.networks = self._split_names(self.networks_string)

        self.log.debug(self.machines)
        self.log.debug(self.networks)

    @staticmethod
    def _split_names(names_string):
        """Split a comma-separated list into stripped, non-empty names."""
        return [name.strip() for name in names_string.split(",") if name.strip()]

    def get_networks(self):
        """Return a dict mapping network name to a ``network`` object."""
        networks = {}
        for _network in self.networks:
            self.log.debug(_network)
            networks.setdefault(_network, network(os.path.normpath(self.path + "/" + self.config[_network]["xml_file"])))
        return networks

    def get_machines(self):
        """Return a dict mapping machine name to a ``vm`` object."""
        machines = {}
        for _machine in self.machines:
            self.log.debug(_machine)
            machines.setdefault(_machine, vm(
                os.path.normpath(self.path + "/" + self.config[_machine]["xml_file"]),
                os.path.normpath(self.path + "/" + self.config[_machine]["snapshot_xml_file"]),
                self.config[_machine]["image"],
                self.config[_machine]["root_uid"],
                self.config[_machine]["username"],
                self.config[_machine]["password"]))

        return machines

    @property
    def machine_names(self):
        """Machine names in settings order."""
        return self.machines

    @property
    def network_names(self):
        """Network names in settings order."""
        return self.networks
649
650
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()

    # Without -d the test directory would be None and fail deep inside test().
    parser.add_argument("-d", "--directory", dest="dir", required=True,
                        help="directory containing the test's settings and recipe")

    args = parser.parse_args()

    currenttest = test(args.dir)
    currenttest.read_settings()
    currenttest.virtual_environ_setup()
    currenttest.load_recipe()

    # Assume failure until the recipe has run without a failing step.
    exit_code = 1
    try:
        currenttest.virtual_environ_start()
        # run_recipe() returns False when a step did not behave as expected.
        if currenttest.run_recipe() is not False:
            exit_code = 0
    except Exception as e:
        print(e)
    finally:
        # Always tear the environment down, even after a failure.
        currenttest.virtual_environ_stop()

    sys.exit(exit_code)
671