Support copying of files into all machines
[nitsi.git] / test.py
1 #!/usr/bin/python3
2
3 import serial
4
5 import re
6 from time import sleep
7 import sys
8
9 import libvirt
10
11 import xml.etree.ElementTree as ET
12
13 import os
14
15 import configparser
16
17 from disk import disk
18
class log():
    """Tiny console logger that prints prefixed messages to stdout.

    Debug messages are only printed when the configured level is 4 or
    higher; error messages are always printed.
    """

    def __init__(self, log_level):
        # Numeric verbosity; debug() compares it against 4.
        self.log_level = log_level

    def debug(self, string):
        """Print *string* with a ``DEBUG: `` prefix if the level allows it."""
        if not self.log_level >= 4:
            return
        message = "DEBUG: {}".format(string)
        print(message)

    def error(self, string):
        """Print *string* with an ``ERROR: `` prefix, whatever the level."""
        message = "ERROR: {}".format(string)
        print(message)
29
class libvirt_con():
    """Lazily opened libvirt connection for one URI.

    The connection is not opened in the constructor; it is opened the first
    time the ``con`` property is read and then reused.
    """

    def __init__(self, uri):
        self.log = log(4)
        self.uri = uri
        # Stays None until the ``con`` property opens the connection.
        self.connection = None

    def get_domain_from_name(self, name):
        """Look up and return the domain called *name*.

        Raises:
            RuntimeError: if the lookup returns no domain.
        """
        dom = self.con.lookupByName(name)

        if dom is None:
            raise RuntimeError("No domain named: {}".format(name))
        return dom

    @property
    def con(self):
        """Return the libvirt connection, opening it on first use.

        Raises:
            libvirt.libvirtError: if the connection cannot be opened. The
                failure is logged and re-raised instead of handing None to
                the caller while reporting a successful connect.
        """
        if self.connection is None:
            try:
                self.connection = libvirt.open(self.uri)
            except libvirt.libvirtError:
                self.log.error("Could not connect to: {}".format(self.uri))
                raise

            # Only reached when the open above succeeded.
            self.log.debug("Connected to: {}".format(self.uri))

        return self.connection
55
56
class vm():
    """A libvirt domain together with its disk image and serial console.

    The domain and snapshot definitions are read from XML files in the
    constructor; the domain itself only exists after define() was called.
    """

    def __init__(self, vm_xml_file, snapshot_xml_file, image, root_uid, username, password):
        self.log = log(4)
        self.con = libvirt_con("qemu:///system")

        # None marks a definition that could not be read; define() and
        # create_snapshot() refuse to run without their XML.
        self.vm_xml = self._read_xml(vm_xml_file)
        self.snapshot_xml = self._read_xml(snapshot_xml_file)

        self.image = image

        if not os.path.isfile(self.image):
            self.log.error("No such file: {}".format(self.image))

        self.root_uid = root_uid
        self.disk = disk(image)

        self.username = username
        self.password = password

    def _read_xml(self, path):
        # Return the content of path, or None (after logging) if it is missing.
        try:
            with open(path) as fobj:
                return fobj.read()
        except FileNotFoundError:
            self.log.error("No such file: {}".format(path))
            return None

    def define(self):
        """Define the domain from the XML read in the constructor.

        Raises:
            RuntimeError: if the XML is missing or no domain is returned.
        """
        if self.vm_xml is None:
            self.log.error("Could not define VM")
            raise RuntimeError("No domain XML available")

        self.dom = self.con.con.defineXML(self.vm_xml)
        if self.dom is None:
            self.log.error("Could not define VM")
            raise RuntimeError("Could not define VM")

    def start(self):
        """Start the defined domain; raise RuntimeError on a negative result."""
        if self.dom.create() < 0:
            self.log.error("Could not start VM")
            raise RuntimeError("Could not start VM")

    def shutdown(self):
        """Request a shutdown of the domain if it is currently running."""
        if not self.is_running():
            self.log.error("Domain is not running")
            return

        if self.dom.shutdown() < 0:
            self.log.error("Could not shutdown VM")
            raise RuntimeError("Could not shutdown VM")

    def undefine(self):
        """Remove the domain definition."""
        self.dom.undefine()

    def create_snapshot(self):
        """Create a snapshot from the snapshot XML and remember it.

        Raises:
            RuntimeError: if the XML is missing or no snapshot is returned.
        """
        if self.snapshot_xml is None:
            self.log.error("Could not create snapshot")
            raise RuntimeError("No snapshot XML available")

        self.snapshot = self.dom.snapshotCreateXML(self.snapshot_xml)

        if not self.snapshot:
            self.log.error("Could not create snapshot")
            raise RuntimeError("Could not create snapshot")

    def revert_snapshot(self):
        """Revert the domain to the remembered snapshot, then delete it."""
        self.dom.revertToSnapshot(self.snapshot)
        self.snapshot.delete()

    def is_running(self):
        """Return True if libvirt reports the domain state as running."""
        state, _reason = self.dom.state()
        return state == libvirt.VIR_DOMAIN_RUNNING

    def get_serial_device(self):
        """Return the host path of the domain's serial device.

        The path is taken from the ``devices/serial/source`` element of the
        live domain XML.

        Raises:
            RuntimeError: if the domain is not running or its XML contains
                no serial source element.
        """
        if not self.is_running():
            raise RuntimeError("Domain is not running")

        xml_root = ET.fromstring(self.dom.XMLDesc(0))

        elem = xml_root.find("./devices/serial/source")
        if elem is None:
            # Without this check the failure is an opaque AttributeError.
            self.log.error("Domain has no serial device")
            raise RuntimeError("Domain has no serial device")
        return elem.get("path")

    def check_is_booted_up(self):
        """Block until the domain answers on its serial console."""
        serial_con = connection(self.get_serial_device())

        serial_con.write("\n")
        # This will block till the domain is booted up
        serial_con.read(1)

    def login(self):
        """Open the serial console and log in with the stored credentials.

        Failures are logged together with their cause and not re-raised;
        in that case ``serial_con`` may be missing or not logged in.
        """
        try:
            self.serial_con = connection(self.get_serial_device(), username=self.username)
            self.serial_con.login(self.password)
        except Exception as e:
            self.log.error("Could not connect to the domain via serial console: {}".format(e))

    def cmd(self, cmd):
        """Run *cmd* on the serial console and return what command() returns."""
        return self.serial_con.command(cmd)

    def copy_in(self, fr, to):
        """Copy *fr* into the disk image at *to*.

        The root filesystem (identified by ``root_uid``) is mounted at "/"
        for the copy. Errors are logged, and unmounting and closing the
        disk are attempted in every case.
        """
        try:
            self.disk.mount(self.root_uid, "/")
            self.disk.copy_in(fr, to)
        except BaseException as e:
            self.log.error(e)
        finally:
            self.disk.umount("/")
            self.disk.close()
165
class connection():
    """Buffered, line-oriented access to a serial console.

    Bytes that were read from the serial device but not yet consumed are
    kept in ``self.buffer``; read(), peek() and readline() serve from that
    buffer first.
    """

    def __init__(self, device, username=None):
        self.buffer = b""
        # Compiled lazily by back_at_prompt().
        self.back_at_prompt_pattern = None
        self.username = username
        self.log = log(1)
        self.con = serial.Serial(device)

    @staticmethod
    def _decode(data):
        # Console output may contain bytes that are not valid UTF-8;
        # replace them instead of raising UnicodeDecodeError.
        return data.decode(errors="replace")

    def read(self, size=1):
        """Return *size* bytes, taking buffered bytes before reading more."""
        if len(self.buffer) >= size:
            # The buffer alone can satisfy the request
            data = self.buffer[:size]
            self.buffer = self.buffer[size:]
            return data

        data = self.buffer
        # Only the part the buffer could not provide is read from the device
        size = size - len(self.buffer)
        self.buffer = b""
        return data + self.con.read(size)

    def peek(self, size=1):
        """Return the first *size* bytes without consuming them."""
        if len(self.buffer) < size:
            self.buffer += self.con.read(size=size - len(self.buffer))

        return self.buffer[:size]

    def readline(self):
        """Return the next line, including its trailing newline."""
        self.log.debug(self.buffer)
        self.buffer = self.buffer + self.con.read(self.con.in_waiting)
        if b"\n" in self.buffer:
            size = self.buffer.index(b"\n") + 1
            self.log.debug("We have a whole line in the buffer")
            self.log.debug(self.buffer)
            self.log.debug("We split at {}".format(size))
            data = self.buffer[:size]
            self.buffer = self.buffer[size:]
            self.log.debug(data)
            self.log.debug(self.buffer)
            return data

        # No complete line buffered yet: let the device finish the line
        data = self.buffer
        self.buffer = b""
        return data + self.con.readline()

    def back_at_prompt(self):
        """Return True if the pending output contains a ``[user@...]#`` prompt."""
        data = self.peek()
        if not data == b"[":
            return False

        # We need to use self.in_waiting because with self.con.in_waiting we get
        # not the complete string
        size = len(self.buffer) + self.in_waiting
        data = self.peek(size)

        if self.back_at_prompt_pattern is None:
            # The username is escaped so that regex metacharacters in it
            # are matched literally.
            self.back_at_prompt_pattern = re.compile(
                r"^\[{}@.+\]#".format(re.escape(str(self.username))), re.MULTILINE)

        return bool(self.back_at_prompt_pattern.search(self._decode(data)))

    def log_console_line(self, line):
        """Write a line of console output to stdout."""
        self.log.debug("Get in function log_console_line()")
        sys.stdout.write(line)

    @property
    def in_waiting(self):
        """Number of bytes waiting on the device once the count stops changing.

        The count is polled every 0.5 seconds until two consecutive polls
        agree, so every access takes at least one second.
        """
        in_waiting_before = 0
        sleep(0.5)

        while in_waiting_before != self.con.in_waiting:
            in_waiting_before = self.con.in_waiting
            sleep(0.5)

        return self.con.in_waiting

    def line_in_buffer(self):
        """Return True if the buffer holds at least one complete line."""
        return b"\n" in self.buffer

    def readline2(self, pattern=None):
        """Read and echo bytes from the device until *pattern* matches or a line ends.

        Reads directly from the device and does not use ``self.buffer``.

        Returns:
            ``{"string": text, "return-code": 1}`` as soon as *pattern*
            matches the text read so far, otherwise ``{"return-code": 0}``
            once a newline was read.
        """
        import codecs

        if pattern:
            pattern = re.compile(pattern)

        # Decoding byte by byte fails on multi-byte characters; the
        # incremental decoder emits a character once all its bytes arrived.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        string = ""

        while True:
            char = self.con.read(1)
            text = decoder.decode(char)
            string = string + text
            print(text, end="")

            if pattern and pattern.match(string):
                return {"string": string, "return-code": 1}

            if char == b"\n":
                return {"return-code": 0}

    def check_logged_in(self, username):
        """Return True if the next console line starts with *username*'s prompt."""
        pattern = r"^\[" + re.escape(username) + r"@.+\]#"
        # readline2() is the variant that accepts a pattern; readline() does not.
        data = self.readline2(pattern=pattern)
        if data["return-code"] == 1:
            print("We are logged in")
            return True

        print("We are  not logged in")
        return False

    def print_lines_in_buffer(self):
        """Print every complete line that is currently pending."""
        while True:
            self.log.debug("Fill buffer ...")
            self.peek(len(self.buffer) + self.in_waiting)
            self.log.debug("Current buffer length: {}".format(len(self.buffer)))
            if not self.line_in_buffer():
                self.log.debug("We have printed all lines in the buffer")
                break

            while self.line_in_buffer():
                data = self.readline()
                self.log_console_line(self._decode(data))

    def login(self, password):
        """Log in on the console as ``self.username`` with *password*.

        Returns:
            False if no username is set, True once a prompt is seen.

        NOTE(review): a failed login is not detected; the final loop keeps
        reading lines until a prompt shows up.
        """
        if self.username is None:
            self.log.error("Username cannot be blank")
            return False

        self.print_lines_in_buffer()

        # Hit enter to see what we get
        self.con.write(b'\n')
        # We get two new lines \r\n ?
        data = self.readline()
        self.log_console_line(self._decode(data))

        self.print_lines_in_buffer()

        if self.back_at_prompt():
            self.log.debug("We are already logged in.")
            return True

        # Read all line till we get login:
        login_pattern = re.compile(r"^.*login: ")
        while True:
            # We need to use self.in_waiting because with self.con.in_waiting we get
            # not the complete string
            size = len(self.buffer) + self.in_waiting
            data = self.peek(size)

            if login_pattern.search(self._decode(data)):
                break

            self.log.debug("The pattern does not match")
            self.log_console_line(self._decode(self.readline()))

        # We can login
        string = "{}\n".format(self.username)
        self.con.write(string.encode())
        self.con.flush()
        # read the login out of the buffer
        data = self.readline()
        self.log.debug("This is the login:{}".format(data))
        self.log_console_line(self._decode(data))

        # We need to wait here till we get the full string "Password:".
        # The value is not used, but self.in_waiting waits the correct amount of time
        _ = self.in_waiting

        string = "{}\n".format(password)
        self.con.write(string.encode())
        self.con.flush()
        # Print the 'Password:' line
        data = self.readline()
        self.log_console_line(self._decode(data))

        while not self.back_at_prompt():
            # This will fail if the login failed so we need to look for the failed keyword
            data = self.readline()
            self.log_console_line(self._decode(data))

        return True

    def write(self, string):
        """Encode *string* and send it to the device."""
        self.log.debug(string)
        self.con.write(string.encode())
        self.con.flush()

    def command(self, command):
        """Run *command* on the console and return its exit code as a string.

        The command is followed by ``echo "END: $?"``; the last line read
        before the prompt returns is that marker, and the text after
        ``END: `` is returned stripped.
        """
        self.write("{}; echo \"END: $?\"\n".format(command))

        # We need to read out the prompt for this command first
        # If we do not do this we will break the loop immediately
        # because the prompt for this command is still in the buffer
        data = self.readline()
        self.log_console_line(self._decode(data))

        while not self.back_at_prompt():
            data = self.readline()
            self.log_console_line(self._decode(data))

        # We saved our exit code in data (the last line)
        text = self._decode(data)
        self.log.debug(text)
        text = text.replace("END: ", "")
        self.log.debug(text)
        self.log.debug(text.strip())
        return text.strip()
388
389
# A class which defines and undefines a virtual network based on an xml file
class network():
    """A libvirt virtual network described by an XML file."""

    def __init__(self, network_xml_file):
        self.log = log(4)
        self.con = libvirt_con("qemu:///system")
        try:
            with open(network_xml_file) as fobj:
                self.network_xml = fobj.read()
        except FileNotFoundError:
            # Report the file that was actually requested.
            self.log.error("No such file: {}".format(network_xml_file))

    def define(self):
        """Define the network from the XML read in the constructor."""
        self.network = self.con.con.networkDefineXML(self.network_xml)

        # Check the returned object, not the class of the same name.
        if self.network is None:
            self.log.error("Failed to define virtual network")

    def start(self):
        """Start the defined network."""
        self.network.create()

    def undefine(self):
        """Tear the network down.

        NOTE(review): despite the name this only calls destroy(); whether
        the definition should be removed as well needs to be confirmed.
        """
        self.network.destroy()
412
413
414
class RecipeExeption(Exception):
    """Raised when a recipe cannot be parsed or an include loop is detected."""
417
418
419
# Reads the recipe, checks that the syntax is valid
# and provides tuples with the ( machine, extra, command ) structure
class recipe():
    """Parsed recipe file.

    Each recipe line has the form ``<machine>[ <extra>]: <command>``.
    ``machine`` may be a comma separated list, the word ``all`` (every
    machine named anywhere in the recipe) or ``include``, in which case the
    command part is the directory of another recipe, relative to this one.
    """

    def __init__(self, path, circle=None):
        """Read the raw lines of the recipe at *path*.

        Args:
            path: path of the recipe file.
            circle: paths of the recipes that are currently being included,
                used to detect include loops. Callers outside this class
                leave it out.

        Raises:
            RecipeExeption: if the recipe file cannot be opened.
        """
        self.log = log(4)
        self.recipe_file = path
        self.path = os.path.dirname(self.recipe_file)
        self.log.debug("Path of recipe is: {}".format(self.recipe_file))
        self._recipe = None
        self._machines = None

        # A non-empty include chain means another recipe includes this one.
        self.in_recursion = bool(circle)
        # Private copy: a default list shared between instances would make
        # every later top-level recipe look like an included one.
        self.circle = list(circle) if circle else []
        self.log.debug(self.circle)

        try:
            with open(self.recipe_file) as fobj:
                self.raw_recipe = fobj.readlines()
        except OSError as error:
            self.log.error("No such file: {}".format(self.recipe_file))
            raise RecipeExeption("No such file: {}".format(self.recipe_file)) from error

    @property
    def recipe(self):
        """List of ``(machine, extra, command)`` tuples, parsed on first use."""
        if not self._recipe:
            self.parse()

        return self._recipe

    @property
    def machines(self):
        """Names of all machines used in the recipe, in order of appearance."""
        if not self._machines:
            self._machines = self._machines_in(self.recipe)

        return self._machines

    @staticmethod
    def _machines_in(entries):
        # Collect the distinct machine names of entries, skipping "all".
        names = []
        for entry in entries:
            if entry[0] != "all" and entry[0] not in names:
                names.append(entry[0])
        return names

    def parse(self):
        """Parse the raw lines into ``self._recipe``.

        Included recipes are parsed recursively. In the outermost recipe
        every ``all`` entry is replaced by one entry per machine once the
        whole recipe is known.

        Raises:
            RecipeExeption: on a malformed line or an include loop.
        """
        entries = []
        for lineno, line in enumerate(self.raw_recipe, start=1):
            # Split at the first colon only, the command may contain more.
            head, separator, cmd = line.partition(":")
            if not separator:
                self.log.error("Error parsing the recipe in line {}".format(lineno))
                raise RecipeExeption("Error parsing the recipe in line {}".format(lineno))

            cmd = cmd.strip()
            fields = head.split()
            if not fields:
                self.log.error("Failed to parse the recipe in line {}".format(lineno))
                raise RecipeExeption("Failed to parse the recipe in line {}".format(lineno))

            machine = fields[0]
            extra = fields[1] if len(fields) == 2 else ""

            # We could get a machine here or a include statement
            if machine == "include":
                path = os.path.normpath(self.path + "/" + cmd) + "/recipe"
                if path in self.circle:
                    self.log.error("Detect import loop!")
                    raise RecipeExeption("Detect import loop!")
                # Hand down a new chain so that including the same recipe
                # twice next to each other is not taken for a loop.
                included = recipe(path, circle=self.circle + [path])
                entries.extend(included.recipe)
            else:
                # Support also something like 'alice,bob: echo'
                for name in machine.split(","):
                    entries.append((name.strip(), extra, cmd))

        if not self.in_recursion:
            # Expand "all" only after every line was read, otherwise
            # machines that appear later in the recipe would be missed.
            machines = self._machines_in(entries)
            expanded = []
            for entry in entries:
                if entry[0] != "all":
                    expanded.append(entry)
                else:
                    for name in machines:
                        expanded.append((name, entry[1], entry[2]))
            entries = expanded
            self._machines = machines

        self._recipe = entries
520
521
522
class test():
    """One test: a directory with a ``settings`` file and a ``recipe`` file."""

    def __init__(self, path):
        """Resolve the test directory and locate its settings and recipe.

        Raises:
            Exception: whatever os.path.abspath() raises for *path*; the
                failure is logged first.
        """
        self.log = log(4)
        try:
            self.path = os.path.abspath(path)
        except Exception:
            # Without a path nothing below can work, so do not continue.
            self.log.error("Could not get absolute path")
            raise

        self.log.debug(self.path)

        self.settings_file = "{}/settings".format(self.path)
        if not os.path.isfile(self.settings_file):
            self.log.error("No such file: {}".format(self.settings_file))

        self.recipe_file = "{}/recipe".format(self.path)
        if not os.path.isfile(self.recipe_file):
            self.log.error("No such file: {}".format(self.recipe_file))

    def read_settings(self):
        """Load name, description, copy settings and environment from the settings file.

        ``Copy_from`` is a comma separated list; each entry and the
        environment path are resolved relative to the test directory.
        """
        self.config = configparser.ConfigParser()
        self.config.read(self.settings_file)
        self.name = self.config["DEFAULT"]["Name"]
        self.description = self.config["DEFAULT"]["Description"]
        self.copy_to = self.config["DEFAULT"]["Copy_to"]
        self.copy_from = [
            os.path.normpath(self.path + "/" + entry.strip())
            for entry in self.config["DEFAULT"]["Copy_from"].split(",")
        ]

        self.virtual_environ_name = self.config["VIRTUAL_ENVIRONMENT"]["Name"]
        self.virtual_environ_path = self.config["VIRTUAL_ENVIRONMENT"]["Path"]
        self.virtual_environ_path = os.path.normpath(self.path + "/" + self.virtual_environ_path)

    def virtual_environ_setup(self):
        """Create the objects for all networks and machines of the environment."""
        self.virtual_environ = virtual_environ(self.virtual_environ_path)

        self.virtual_networks = self.virtual_environ.get_networks()

        self.virtual_machines = self.virtual_environ.get_machines()

    def virtual_environ_start(self):
        """Start all networks and machines, then log in on every machine.

        Each machine is defined, snapshotted and gets the configured files
        copied in before it is started.
        """
        for name in self.virtual_environ.network_names:
            self.virtual_networks[name].define()
            self.virtual_networks[name].start()

        for name in self.virtual_environ.machine_names:
            self.virtual_machines[name].define()
            self.virtual_machines[name].create_snapshot()
            self.virtual_machines[name].copy_in(self.copy_from, self.copy_to)
            self.virtual_machines[name].start()

        self.log.debug("Try to login on all machines")
        for name in self.virtual_environ.machine_names:
            self.virtual_machines[name].login()

    def load_recipe(self):
        """Load the recipe of this test; exit with status 1 if that fails."""
        try:
            self.recipe = recipe(self.recipe_file)
        except Exception:
            self.log.error("Failed to load recipe")
            sys.exit(1)

    def run_recipe(self):
        """Run every recipe line on its machine and stop at the first failure.

        A line without extra must return "0"; a line whose extra is "!"
        must not return "0".

        Returns:
            True if every line behaved as required, otherwise False.
        """
        for line in self.recipe.recipe:
            return_value = self.virtual_machines[line[0]].cmd(line[2])
            self.log.debug("Return value is: {}".format(return_value))
            if return_value != "0" and line[1] == "":
                self.log.error("Failed to execute command '{}' on {}, return code: {}".format(line[2],line[0], return_value))
                return False
            elif return_value == "0" and line[1] == "!":
                self.log.error("Succeded to execute command '{}' on {}, return code: {}".format(line[2],line[0],return_value))
                return False
            else:
                self.log.debug("Command '{}' on {} returned with: {}".format(line[2],line[0],return_value))

        # Report success explicitly instead of falling through with None.
        return True

    def virtual_environ_stop(self):
        """Shut down, revert and undefine all machines, then remove the networks."""
        for name in self.virtual_environ.machine_names:
            self.virtual_machines[name].shutdown()
            self.virtual_machines[name].revert_snapshot()
            self.virtual_machines[name].undefine()

        for name in self.virtual_environ.network_names:
            self.virtual_networks[name].undefine()
612
613
# Should return all vms and networks in a list
# and should provide the path to the necessary xml files
class virtual_environ():
    """A virtual environment: the machines and networks named in a settings file.

    The ``DEFAULT`` section names the environment and lists its machines
    and networks; every machine and network has a section of its own.
    """

    def __init__(self, path):
        """Read the ``settings`` file in the directory *path*.

        Raises:
            Exception: whatever os.path.abspath() raises for *path*; the
                failure is logged first.
        """
        self.log = log(4)
        try:
            self.path = os.path.abspath(path)
        except Exception:
            # Without a path nothing below can work, so do not continue.
            self.log.error("Could not get absolute path")
            raise

        self.log.debug(self.path)

        self.settings_file = "{}/settings".format(self.path)
        if not os.path.isfile(self.settings_file):
            self.log.error("No such file: {}".format(self.settings_file))

        self.log.debug(self.settings_file)
        self.config = configparser.ConfigParser()
        self.config.read(self.settings_file)
        self.name = self.config["DEFAULT"]["name"]
        self.machines_string = self.config["DEFAULT"]["machines"]
        self.networks_string = self.config["DEFAULT"]["networks"]

        self.machines = self._split_names(self.machines_string)
        self.networks = self._split_names(self.networks_string)

        self.log.debug(self.machines)
        self.log.debug(self.networks)

    @staticmethod
    def _split_names(value):
        # Split a comma separated list and drop empty entries, so that an
        # empty value or a trailing comma does not produce a "" name.
        return [name.strip() for name in value.split(",") if name.strip()]

    def get_networks(self):
        """Return a dict that maps each network name to a network object."""
        networks = {}
        for name in self.networks:
            self.log.debug(name)
            if name in networks:
                # A name listed twice keeps its first object.
                continue
            networks[name] = network(
                os.path.normpath(self.path + "/" + self.config[name]["xml_file"]))
        return networks

    def get_machines(self):
        """Return a dict that maps each machine name to a vm object."""
        machines = {}
        for name in self.machines:
            self.log.debug(name)
            if name in machines:
                # A name listed twice keeps its first object.
                continue
            section = self.config[name]
            machines[name] = vm(
                os.path.normpath(self.path + "/" + section["xml_file"]),
                os.path.normpath(self.path + "/" + section["snapshot_xml_file"]),
                section["image"],
                section["root_uid"],
                section["username"],
                section["password"])

        return machines

    @property
    def machine_names(self):
        """Names of the machines of this environment."""
        return self.machines

    @property
    def network_names(self):
        """Names of the networks of this environment."""
        return self.networks
676
677
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()

    # The directory is mandatory; without it there is no test to run.
    parser.add_argument("-d", "--directory", dest="dir", required=True,
                        help="directory which contains the settings and recipe of the test")

    args = parser.parse_args()

    currenttest = test(args.dir)
    currenttest.read_settings()
    currenttest.virtual_environ_setup()
    currenttest.load_recipe()

    success = False
    try:
        currenttest.virtual_environ_start()
        # run_recipe() returns False when a recipe line failed.
        success = currenttest.run_recipe() is not False
    except BaseException as e:
        print(e)
    finally:
        # Always tear the environment down, even after a failure.
        currenttest.virtual_environ_stop()

    # Let the caller see through the exit status whether the test passed.
    sys.exit(0 if success else 1)
698