]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
Complete broadcast support (both raw and via port mapper CALLIT)
authorGuido van Rossum <guido@python.org>
Mon, 21 Dec 1992 14:32:06 +0000 (14:32 +0000)
committerGuido van Rossum <guido@python.org>
Mon, 21 Dec 1992 14:32:06 +0000 (14:32 +0000)
Demo/rpc/rpc.py

index d1c2c5e599ec4898f6852c0a72a10221d27af199..00397dd852e4edfa663bdd0d78d6431396e32e76 100644 (file)
@@ -1,4 +1,4 @@
-# Implement (a subset of) Sun RPC, version 2 -- RFC1057.
+# Sun RPC version 2 -- RFC1057.
 
 # XXX There should be separate exceptions for the various reasons why
 # XXX an RPC can fail, rather than using RuntimeError for everything
@@ -177,8 +177,8 @@ class Client:
                self.port = port
                self.makesocket() # Assigns to self.sock
                self.bindsocket()
-               self.sock.connect((host, port))
-               self.lastxid = 0
+               self.connsocket()
+               self.lastxid = 0 # XXX should be more random?
                self.addpackers()
                self.cred = None
                self.verf = None
@@ -191,6 +191,10 @@ class Client:
                # This MUST be overridden
                raise RuntimeError, 'makesocket not defined'
 
+       def connsocket(self):
+               # Override this if you don't want/need a connection
+               self.sock.connect((self.host, self.port))
+
        def bindsocket(self):
                # Override this to bind to a different port (e.g. reserved)
                self.sock.bind(('', 0))
@@ -200,6 +204,21 @@ class Client:
                self.packer = Packer().init()
                self.unpacker = Unpacker().init('')
 
+       def make_call(self, proc, args, pack_func, unpack_func):
+               # Don't normally override this (but see Broadcast)
+               if pack_func is None and args is not None:
+                       raise TypeError, 'non-null args with null pack_func'
+               self.start_call(proc)
+               if pack_func:
+                       pack_func(args)
+               self.do_call()
+               if unpack_func:
+                       result = unpack_func()
+               else:
+                       result = None
+               self.unpacker.done()
+               return result
+
        def start_call(self, proc):
                # Don't override this
                self.lastxid = xid = self.lastxid + 1
@@ -209,14 +228,10 @@ class Client:
                p.reset()
                p.pack_callheader(xid, self.prog, self.vers, proc, cred, verf)
 
-       def do_call(self, *rest):
+       def do_call(self):
                # This MUST be overridden
                raise RuntimeError, 'do_call not defined'
 
-       def end_call(self):
-               # Don't override this
-               self.unpacker.done()
-
        def mkcred(self):
                # Override this to use more powerful credentials
                if self.cred == None:
@@ -230,9 +245,7 @@ class Client:
                return self.verf
 
        def Null(self):                 # Procedure 0 is always like this
-               self.start_call(0)
-               self.do_call(0)
-               self.end_call()
+               return self.make_call(0, None, None, None)
 
 
 # Record-Marking standard support
@@ -293,23 +306,14 @@ def bindresvport(sock, host):
        raise RuntimeError, 'can\'t assign reserved port'
 
 
-# Raw TCP-based client
+# Client using TCP to a specific port
 
 class RawTCPClient(Client):
 
        def makesocket(self):
                self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
 
-       def start_call(self, proc):
-               self.lastxid = xid = self.lastxid + 1
-               cred = self.mkcred()
-               verf = self.mkverf()
-               p = self.packer
-               p.reset()
-               p.pack_callheader(xid, self.prog, self.vers, proc, cred, verf)
-
-       def do_call(self, *rest):
-               # rest is used for UDP buffer size; ignored for TCP
+       def do_call(self):
                call = self.packer.get_buf()
                sendrecord(self.sock, call)
                reply = recvrecord(self.sock)
@@ -321,41 +325,25 @@ class RawTCPClient(Client):
                        raise RuntimeError, 'wrong xid in reply ' + `xid` + \
                                ' instead of ' + `self.lastxid`
 
-       def end_call(self):
-               self.unpacker.done()
-
 
-# Raw UDP-based client
+# Client using UDP to a specific port
 
 class RawUDPClient(Client):
 
        def makesocket(self):
                self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
 
-       def start_call(self, proc):
-               self.lastxid = xid = self.lastxid + 1
-               cred = self.mkcred()
-               verf = self.mkverf()
-               p = self.packer
-               p.reset()
-               p.pack_callheader(xid, self.prog, self.vers, proc, cred, verf)
-
-       def do_call(self, *rest):
+       def do_call(self):
+               call = self.packer.get_buf()
+               self.sock.send(call)
                try:
                        from select import select
                except ImportError:
                        print 'WARNING: select not found, RPC may hang'
                        select = None
-               if len(rest) == 0:
-                       bufsize = 8192
-               elif len(rest) > 1:
-                       raise TypeError, 'too many args'
-               else:
-                       bufsize = rest[0] + 512
-               call = self.packer.get_buf()
+               BUFSIZE = 8192 # Max UDP buffer size
                timeout = 1
                count = 5
-               self.sock.send(call)
                while 1:
                        r, w, x = [self.sock], [], []
                        if select:
@@ -367,7 +355,7 @@ class RawUDPClient(Client):
 ##                             print 'RESEND', timeout, count
                                self.sock.send(call)
                                continue
-                       reply = self.sock.recv(bufsize)
+                       reply = self.sock.recv(BUFSIZE)
                        u = self.unpacker
                        u.reset(reply)
                        xid, verf = u.unpack_replyheader()
@@ -376,13 +364,70 @@ class RawUDPClient(Client):
                                continue
                        break
 
-       def end_call(self):
-               self.unpacker.done()
 
+# Client using UDP broadcast to a specific port
 
-# Port mapper interface
+class RawBroadcastUDPClient(RawUDPClient):
+
+       def init(self, bcastaddr, prog, vers, port):
+               self = RawUDPClient.init(self, bcastaddr, prog, vers, port)
+               self.reply_handler = None
+               self.timeout = 30
+               return self
+
+       def connsocket(self):
+               # Don't connect -- use sendto
+               self.sock.allowbroadcast(1)
+
+       def set_reply_handler(self, reply_handler):
+               self.reply_handler = reply_handler
 
-# XXX CALLIT is not implemented
+       def set_timeout(self, timeout):
+               self.timeout = timeout # Use None for infinite timeout
+
+       def make_call(self, proc, args, pack_func, unpack_func):
+               if pack_func is None and args is not None:
+                       raise TypeError, 'non-null args with null pack_func'
+               self.start_call(proc)
+               if pack_func:
+                       pack_func(args)
+               call = self.packer.get_buf()
+               self.sock.sendto(call, (self.host, self.port))
+               try:
+                       from select import select
+               except ImportError:
+                       print 'WARNING: select not found, broadcast will hang'
+                       select = None
+               BUFSIZE = 8192 # Max UDP buffer size (for reply)
+               replies = []
+               if unpack_func is None:
+                       def dummy(): pass
+                       unpack_func = dummy
+               while 1:
+                       r, w, x = [self.sock], [], []
+                       if select:
+                               if self.timeout is None:
+                                       r, w, x = select(r, w, x)
+                               else:
+                                       r, w, x = select(r, w, x, self.timeout)
+                       if self.sock not in r:
+                               break
+                       reply, fromaddr = self.sock.recvfrom(BUFSIZE)
+                       u = self.unpacker
+                       u.reset(reply)
+                       xid, verf = u.unpack_replyheader()
+                       if xid <> self.lastxid:
+##                             print 'BAD xid'
+                               continue
+                       reply = unpack_func()
+                       self.unpacker.done()
+                       replies.append((reply, fromaddr))
+                       if self.reply_handler:
+                               self.reply_handler(reply, fromaddr)
+               return replies
+
+
+# Port mapper interface
 
 # Program number, version and (fixed!) port number
 PMAP_PROG = 100000
@@ -421,6 +466,13 @@ class PortMapperPacker(Packer):
        def pack_pmaplist(self, list):
                self.pack_list(list, self.pack_mapping)
 
+       def pack_call_args(self, ca):
+               prog, vers, proc, args = ca
+               self.pack_uint(prog)
+               self.pack_uint(vers)
+               self.pack_uint(proc)
+               self.pack_opaque(args)
+
 
 class PortMapperUnpacker(Unpacker):
 
@@ -434,6 +486,11 @@ class PortMapperUnpacker(Unpacker):
        def unpack_pmaplist(self):
                return self.unpack_list(self.unpack_mapping)
 
+       def unpack_call_result(self):
+               port = self.unpack_uint()
+               res = self.unpack_opaque()
+               return port, res
+
 
 class PartialPortMapperClient:
 
@@ -442,35 +499,29 @@ class PartialPortMapperClient:
                self.unpacker = PortMapperUnpacker().init('')
 
        def Set(self, mapping):
-               self.start_call(PMAPPROC_SET)
-               self.packer.pack_mapping(mapping)
-               self.do_call()
-               res = self.unpacker.unpack_uint()
-               self.end_call()
-               return res
+               return self.make_call(PMAPPROC_SET, mapping, \
+                       self.packer.pack_mapping, \
+                       self.unpacker.unpack_uint)
 
        def Unset(self, mapping):
-               self.start_call(PMAPPROC_UNSET)
-               self.packer.pack_mapping(mapping)
-               self.do_call()
-               res = self.unpacker.unpack_uint()
-               self.end_call()
-               return res
+               return self.make_call(PMAPPROC_UNSET, mapping, \
+                       self.packer.pack_mapping, \
+                       self.unpacker.unpack_uint)
 
        def Getport(self, mapping):
-               self.start_call(PMAPPROC_GETPORT)
-               self.packer.pack_mapping(mapping)
-               self.do_call(4)
-               port = self.unpacker.unpack_uint()
-               self.end_call()
-               return port
+               return self.make_call(PMAPPROC_GETPORT, mapping, \
+                       self.packer.pack_mapping, \
+                       self.unpacker.unpack_uint)
 
        def Dump(self):
-               self.start_call(PMAPPROC_DUMP)
-               self.do_call(8192-512)
-               list = self.unpacker.unpack_pmaplist()
-               self.end_call()
-               return list
+               return self.make_call(PMAPPROC_DUMP, None, \
+                       None, \
+                       self.unpacker.unpack_pmaplist)
+
+       def Callit(self, ca):
+               return self.make_call(PMAPPROC_CALLIT, ca, \
+                       self.packer.pack_call_args, \
+                       self.unpacker.unpack_call_result)
 
 
 class TCPPortMapperClient(PartialPortMapperClient, RawTCPClient):
@@ -487,6 +538,16 @@ class UDPPortMapperClient(PartialPortMapperClient, RawUDPClient):
                        host, PMAP_PROG, PMAP_VERS, PMAP_PORT)
 
 
+class BroadcastUDPPortMapperClient(PartialPortMapperClient, \
+                                  RawBroadcastUDPClient):
+
+       def init(self, bcastaddr):
+               return RawBroadcastUDPClient.init(self, \
+                       bcastaddr, PMAP_PROG, PMAP_VERS, PMAP_PORT)
+
+
+# Generic clients that find their server through the Port mapper
+
 class TCPClient(RawTCPClient):
 
        def init(self, host, prog, vers):
@@ -509,6 +570,51 @@ class UDPClient(RawUDPClient):
                return RawUDPClient.init(self, host, prog, vers, port)
 
 
+class BroadcastUDPClient(Client):
+
+       def init(self, bcastaddr, prog, vers):
+               self.pmap = BroadcastUDPPortMapperClient().init(bcastaddr)
+               self.pmap.set_reply_handler(self.my_reply_handler)
+               self.prog = prog
+               self.vers = vers
+               self.user_reply_handler = None
+               self.addpackers()
+               return self
+
+       def close(self):
+               self.pmap.close()
+
+       def set_reply_handler(self, reply_handler):
+               self.user_reply_handler = reply_handler
+
+       def set_timeout(self, timeout):
+               self.pmap.set_timeout(timeout)
+
+       def my_reply_handler(self, reply, fromaddr):
+               port, res = reply
+               self.unpacker.reset(res)
+               result = self.unpack_func()
+               self.unpacker.done()
+               self.replies.append((result, fromaddr))
+               if self.user_reply_handler is not None:
+                       self.user_reply_handler(result, fromaddr)
+
+       def make_call(self, proc, args, pack_func, unpack_func):
+               self.packer.reset()
+               if pack_func:
+                       pack_func(args)
+               if unpack_func is None:
+                       def dummy(): pass
+                       self.unpack_func = dummy
+               else:
+                       self.unpack_func = unpack_func
+               self.replies = []
+               packed_args = self.packer.get_buf()
+               dummy_replies = self.pmap.Callit( \
+                       (self.prog, self.vers, proc, packed_args))
+               return self.replies
+
+
 # Server classes
 
 # These are not symmetric to the Client classes
@@ -657,14 +763,9 @@ class UDPServer(Server):
 # Simple test program -- dump local portmapper status
 
 def test():
-       import T
-       T.TSTART()
        pmap = UDPPortMapperClient().init('')
-       T.TSTOP()
        pmap.Null()
-       T.TSTOP()
        list = pmap.Dump()
-       T.TSTOP()
        list.sort()
        for prog, vers, prot, port in list:
                print prog, vers,
@@ -674,7 +775,24 @@ def test():
                print port
 
 
-# Server and client test program.
+# Test program for broadcast operation -- dump everybody's portmapper status
+
+def testbcast():
+       import sys
+       if sys.argv[1:]:
+               bcastaddr = sys.argv[1]
+       else:
+               bcastaddr = '<broadcast>'
+       def rh(reply, fromaddr):
+               host, port = fromaddr
+               print host + '\t' + `reply`
+       pmap = BroadcastUDPPortMapperClient().init(bcastaddr)
+       pmap.set_reply_handler(rh)
+       pmap.set_timeout(5)
+       replies = pmap.Getport((100002, 1, IPPROTO_UDP, 0))
+
+
+# Test program for server, with corresponding client
 # On machine A: python -c 'import rpc; rpc.testsvr()'
 # On machine B: python -c 'import rpc; rpc.testclt()' A
 # (A may be == B)
@@ -709,12 +827,9 @@ def testclt():
        # Client for above server
        class C(UDPClient):
                def call_1(self, arg):
-                       self.start_call(1)
-                       self.packer.pack_string(arg)
-                       self.do_call()
-                       reply = self.unpacker.unpack_string()
-                       self.end_call()
-                       return reply
+                       return self.make_call(1, arg, \
+                               self.packer.pack_string, \
+                               self.unpacker.unpack_string)
        c = C().init(host, 0x20000000, 1)
        print 'making call...'
        reply = c.call_1('hello, world, ')