Changeset 6156

Show
Ignore:
Timestamp:
01/29/13 08:34:24 (6 years ago)
Author:
sportzer
Message:

Fixes a number of bugs in various shims

Location:
seattle/branches/repy_v2/shims
Files:
8 modified

Legend:

Unmodified
Added
Removed
  • seattle/branches/repy_v2/shims/compressionshim.repy

    r4694 r6156  
    2222 
    2323 
    24 _compression_shim_global_lock = createlock() 
    25  
    26 def _compression_shim_atomic_operation(target_func): 
     24 
     25 
     26def _compression_socket_atomic_operation(target_func): 
    2727  """ 
    2828  Function decorator. The target function is invoked atomically. 
    2929 
    3030  """ 
    31   def wrapper(*args, **kwargs): 
     31  def wrapper(self, *args, **kwargs): 
    3232    try: 
    33       _compression_shim_global_lock.acquire(True) 
    34       return target_func(*args, **kwargs) 
     33      self._atomic_lock.acquire(True) 
     34      return target_func(self, *args, **kwargs) 
    3535    finally: 
    36       _compression_shim_global_lock.release() 
     36      self._atomic_lock.release() 
    3737 
    3838  return wrapper 
     39 
     40 
     41 
     42 
     43class CompressionSocket(): 
     44 
     45  def __init__(self, socket, shim): 
     46    self.socket = socket 
     47    self.shim = shim 
     48 
     49    # Initialize all the necessary dictionaries and sets. 
     50    self._send_buf = '' 
     51    self._recv_buf = '' 
     52    self._result_buf = '' 
     53 
     54    self._full_lock = createlock() 
     55    self._empty_lock = createlock() 
     56    self._mutex_lock = createlock() 
     57    self._atomic_lock = createlock() 
     58 
     59    self.closed_local = False 
     60    self.closed_remote = False 
     61    self.close_lock = createlock() 
     62 
     63    # FF: The "full" lock is initialized to locked, so that the sending thread   
     64    # will sleep until there is data to send (the lock will be released by the  
     65    # compression thread after it places data in the send buffer) 
     66    self._full_lock.acquire(True) 
     67 
     68    # FF: Create the sending thread for this connection 
     69    createthread(self._send_pending_data) 
     70 
     71 
     72  def close(self): 
     73    self.close_lock.acquire(True) 
     74    result = not self.closed_local 
     75    self.closed_local = True 
     76    self.close_lock.release() 
     77 
     78    try: 
     79      self._full_lock.release() 
     80    except Exception: 
     81      pass 
     82 
     83    return result 
     84 
     85 
     86  @_compression_socket_atomic_operation 
     87  def send(self, msg): 
     88    """  
     89    <Purpose> 
     90      FF: Compression thread. Breaks up the original stream into blocks of  
     91      _SEND_BLOCK_SIZE and compresses individual blocks. Prepends the length to  
     92      each block as the header, adds a 'T' to the end of each block (so that I  
     93      don't have the change the recv part) and adds these blocks to the send 
     94      buffer sequentially if the send buffer is empty.  
     95 
     96    <Arguments> 
     97      Same as repy v2 socket API. 
     98 
     99    <Exceptions> 
     100      Same as repy v2 socket API. 
     101 
     102    <Side Effects> 
     103      Same as repy v2 socket API. 
     104 
     105    <Returns> 
     106      Same as repy v2 socket API. 
     107 
     108    """ 
     109 
     110    if self.closed_local: 
     111      raise SocketClosedLocal("Socket closed locally!") 
     112 
     113    if self.closed_remote: 
     114      raise SocketClosedRemote("Socket closed remotely!") 
     115 
     116    # How much of the original uncompressed data has been sent successfully. 
     117    total_original_bytes_sent = 0 
     118 
     119    # Keep sending the supplied message until no more data to send. 
     120    while msg: 
     121 
     122      # Create a compressed block of data out of the original message. 
     123      uncompressed_block = msg[0 : self.shim._SEND_BLOCK_SIZE] 
     124      msg = msg[len(uncompressed_block) : ] 
     125      block_body = zlib.compress(uncompressed_block) 
     126 
     127      # Set the boolean tag as 'T'. 
     128      block_body += 'T' 
     129 
     130      # Append header information to indicate the length of the block. 
     131      block_header = str(len(block_body)) + ',' 
     132      block_data = block_header + block_body 
     133 
     134      # If the send buffer is empty, place this block in the send buffer 
     135      empty = self._empty_lock.acquire(False) 
     136      mutex = self._mutex_lock.acquire(False) 
     137      if mutex and empty: 
     138        self._send_buf = block_data 
     139        self._mutex_lock.release() 
     140        self._full_lock.release() 
     141        total_original_bytes_sent += len(uncompressed_block) 
     142      else: 
     143        # Release any lock we may have successfully acquired: 
     144        if empty: 
     145          self._empty_lock.release() 
     146        if mutex: 
     147          self._mutex_lock.release() 
     148        break 
     149 
     150    # If we have not sent any data, the system's send buffer must be full. 
     151    if total_original_bytes_sent == 0 and len(msg) > 0: 
     152      raise SocketWouldBlockError 
     153 
     154    return total_original_bytes_sent 
     155 
     156 
     157  @_compression_socket_atomic_operation 
     158  def recv(self, bytes): 
     159    """  
     160    <Purpose> 
     161      Receives as much as possible into the receive buffer until the socket 
     162      blocks. 
     163 
     164      Then, from the receive buffer, we reconstruct all the completely received 
     165      blocks. A complete block is a string in the form of "n,msg", where msg is 
     166      the compressed message and n is its length. 
     167 
     168      For each complete block, we decompress the message and add it to the 
     169      result buffer, which stores the original TCP stream. We return up to the 
     170      requested number of bytes from the result buffer. If the result buffer is 
     171      empty, we raise the blocking exception. 
     172 
     173    <Arguments> 
     174      Same as repy v2 socket API. 
     175 
     176    <Exceptions> 
     177      Same as repy v2 socket API. 
     178 
     179    <Side Effects> 
     180      Same as repy v2 socket API. 
     181 
     182    <Returns> 
     183      Same as repy v2 socket API. 
     184 
     185    """ 
     186 
     187    # Get the result buffer out of the dictionary. 
     188    result_buf = self._result_buf 
     189 
     190    # If our buffer already has data, then we just return it. 
     191    if len(result_buf) > bytes: 
     192      requested_data = result_buf[0 : bytes] 
     193      self._result_buf = result_buf[len(requested_data) : ] 
     194 
     195      return requested_data 
     196 
     197    recv_exception = None 
     198 
     199    # Receive as much as possible into the receive buffer, as long as the socket 
     200    # is active (i.e. not closed remotely or locally). 
     201    while True: 
     202 
     203      try: 
     204        self._recv_buf += self.socket.recv(self.shim._RECV_BLOCK_SIZE) 
     205 
     206        # If we have atleast lenght of bytes of compressed data, then most likely we 
     207        # already have bytes length uncompressed data. We do this so the application  
     208        # doesn't have to wait for CompressionShim to block before returning data. 
     209        if len(self._recv_buf) > bytes: 
     210          break 
     211      # No more data to read from socket. 
     212      except SocketWouldBlockError: 
     213        break 
     214 
     215      # If a different exception occur, we save it first. We will raise it later 
     216      # when we run out of data to return (i.e. empty result buffer). The socket 
     217      # is now considered inactive. We remove it from the active socket set. 
     218      except (SocketClosedLocal, SocketClosedRemote), err: 
     219        self.closed_remote = True 
     220        recv_exception = err 
     221        break 
     222 
     223      # end-try 
     224 
     225    # end-while 
     226 
     227    # Reconstruct all the blocks of compressed messages from the raw TCP 
     228    # stream we received in the receive buffer. For each block, decompress 
     229    # it and add it to the result buffer. 
     230    while True: 
     231      compressed_block = self._reconstruct_blocks() 
     232      if compressed_block is None: 
     233        break 
     234      elif len(compressed_block) > 0: 
     235        result_buf += zlib.decompress(compressed_block) 
     236 
     237    # If there is nothing in the result buffer, we have received all the data. 
     238    if result_buf == '': 
     239 
     240      # If we have saved exceptions, we raise them now, as these exceptions 
     241      # occurred at the end of the received stream. 
     242      if recv_exception: 
     243        raise recv_exception 
     244 
     245      # We simply run out of data without any other exceptions. 
     246      else: 
     247        raise SocketWouldBlockError 
     248 
     249    # Remove the portion requested by the application and return it. 
     250    requested_data = result_buf[0 : bytes] 
     251    self._result_buf = result_buf[len(requested_data) : ] 
     252 
     253    return requested_data 
     254 
     255 
     256  def _send_pending_data(self): 
     257    """ 
     258    FF: Sending thread. When there is data in the send buffer (signaled by the "full" 
     259    lock becoming available), it keeps trying to send until the buffer is empty,  
     260    at which point it signals to the compression thread by releasing the "empty" lock. 
     261    """ 
     262 
     263    while True: 
     264      self._full_lock.acquire(True) 
     265      self._mutex_lock.acquire(True) 
     266 
     267      try: 
     268        send_buf = self._send_buf 
     269        bytes_to_send = len(send_buf) 
     270         
     271        # Send everything in the send buffer and remove it from the buffer. 
     272        while send_buf: 
     273          sent_bytes = 0 
     274          try: 
     275            sent_bytes = self.socket.send(send_buf) 
     276          except SocketWouldBlockError, err: 
     277            pass 
     278          except (SocketClosedLocal, SocketClosedRemote), err: 
     279            self.closed_remote = True 
     280            return 
     281          send_buf = send_buf[sent_bytes : ] 
     282 
     283        if self.closed_local: 
     284          self.socket.close() 
     285          return 
     286 
     287      finally: 
     288        # When the send buffer is empty, release the locks and sleep until there 
     289        # is more data to send 
     290        self._mutex_lock.release() 
     291        try: 
     292          self._empty_lock.release() 
     293        except Exception: 
     294          pass 
     295 
     296 
     297  def _reconstruct_blocks(self): 
     298    """ 
     299    Helper method for the socket_recv method. Reconstructs and returns the 
     300    leftmost complete block starting from the head of the receive buffer. If 
     301    there are no more blocks to reconstruct, returns None. Returns an empty 
     302    string if the block is to be discarded. 
     303 
     304    For instance, we may have received the following stream into the 
     305    receive buffer (the '>' sign denotes the beginning of the receive buffer): 
     306     
     307    > 11,HelloWorldT6,pandaT7,googxxF10,micxxx 
     308 
     309    We start from the beginning and parse the header. We read 11 bytes. The last 
     310    byte is a 'T', so we know the block is complete. We have reconstructed 
     311    'HelloWorld' and we move on (because this method is probably used in a while 
     312    loop). Now the buffer becomes: 
     313 
     314    > 6,pandaT7,googxxF10,micxxx 
     315 
     316    Similarly, we are able to reconstruct the block 'panda'. The header 
     317    subsequently looks like: 
     318 
     319    > 7,googxxF10,micxxx 
     320 
     321    We can also reconstruct the new block as 'googxx', but since the tag is an 
     322    'F', we reject and discard the block. Now, the buffer becomes: 
     323 
     324    >10,micxxx 
     325 
     326    We expect that the block has 10 bytes of data, but the buffer ends before 
     327    that. We assume more data is coming in, so we leave the data on the buffer 
     328    and return all the good blocks we have read so far as 'HelloWorld' and 
     329    'panda'. 
     330 
     331    """ 
     332    # Get the receive buffer.  
     333    recv_buf = self._recv_buf 
     334 
     335    # Base case: empty receive buffer. Return no blocks. 
     336    if recv_buf == '': 
     337      return None 
     338 
     339    # Stores the length of the block as a string. 
     340    block_length_str = '' 
     341 
     342    # Position in the receive buffer to be read. 
     343    cur_position = 0 
     344 
     345    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #  
     346    # Parse the header 
     347    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #  
     348 
     349    while True: 
     350 
     351      # Attempt to read the header of the block. Read a character at a 
     352      # time until comma or 'F'. 
     353      try: 
     354        char = recv_buf[cur_position] 
     355 
     356      # We still haven't obtained the complete header despite reaching the end 
     357      # of the buffer. Hopefully, we will receive more data into the buffer to 
     358      # form a complete header. For now, there is nothing we can do. We keep all 
     359      # the data in the buffer and exit. 
     360      except IndexError: 
     361        return None 
     362 
     363      cur_position += 1 
     364 
     365      # The correct header should end with a comma. Now that we have 
     366      # successfully read the header, we parse the length. 
     367      if char == ',': 
     368        block_length = int(block_length_str) 
     369        break 
     370 
     371      # The header has an 'F', so a SocketWouldBlockError must have occurred as 
     372      # the header is being sent (e.g. '3F'). We discard this block and retry 
     373      # from a position after the 'F' tag. 
     374      elif char == 'F': 
     375        self._recv_buf = recv_buf[cur_position : ] 
     376        return '' 
     377 
     378      # The character is neither a comma or 'F', so we must be still reading the 
     379      # integers in the header. 
     380      elif char.isdigit(): 
     381        block_length_str += char 
     382 
     383      # There must have been a bug! 
     384      else: 
     385        err_str = 'CompressionShim: Invalid characer at position ' + str(cur_position)  
     386        err_str += ' in recv buffer: ' + str(recv_buf) 
     387        raise ShimInternalError(err_str) 
     388 
     389    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #  
     390    # Reconstruct block 
     391    # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #  
     392     
     393    # It is possible that the we are still in the middle of receiving the block, 
     394    # and that we only have a part of the block in the receive buffer. 
     395    # Hopefully, we will receive more data into the buffer to form a complete 
     396    # block. For now, there is nothing we can do. We keep all the data in the 
     397    # buffer and exit. 
     398    if cur_position + block_length > len(recv_buf): 
     399      return None 
     400 
     401    # At this point, we know that we have a complete block. Read block_length of 
     402    # bytes from the receive buffer to form a block and remove it from the 
     403    # buffer. 
     404    block_body = recv_buf[cur_position : cur_position + block_length] 
     405    cur_position += block_length 
     406 
     407    # If it is a bad block, we remove it from the buffer, ignore it and retry 
     408    # after the 'F' tag. 
     409    if block_body.endswith('F'): 
     410      self._recv_buf = recv_buf[cur_position : ] 
     411      return '' 
     412 
     413    # It's a good block! Remove it from the buffer, return it (minus the 'T') 
     414    # and keep reading the other blocks recursively. 
     415    elif block_body.endswith('T'): 
     416      self._recv_buf = recv_buf[cur_position : ] 
     417      return block_body[0 : len(block_body) - 1] 
     418 
     419    # The block should not end with anything else! 
     420    else: 
     421      err_str = 'CompressionShim: Invalid block "' + block_body + '" in buffer: ' + recv_buf 
     422      raise ShimInternalError(err_str) 
     423 
    39424 
    40425 
     
    51436 
    52437    """ 
    53  
    54     # A dictionary that maps a socket to its send buffer, which stores any 
    55     # temporary data pending to be transmitted. 
    56     self._send_buf_dict = {} 
    57  
    58     # A dictionary that maps a socket to its receive buffer, which stores 
    59     # the compressed TCP stream. 
    60     self._recv_buf_dict = {} 
    61  
    62     # A dictionary that maps a socket to any exception it raises while we try to 
    63     # receive as much as possible. 
    64     self._recv_exception_dict = {} 
    65  
    66     # A dictionary that maps a socket to the stream of uncompressed data. 
    67     self._result_buf_dict = {} 
    68  
    69     # A set that stores all the active sockets used in socket_send. If a socket 
    70     # raises an exception other than SocketWouldBlockError, it is considered 
    71     # inactive and thus removed from the set. 
    72     self._active_recv_socket_set = set() 
    73  
    74     # FF: A dictionary that maps a socket to the locks used to synchronize the  
    75     # socket's sending thread and compression thread 
    76     self._full_lock = {} 
    77     self._empty_lock = {} 
    78     self._mutex_lock = {} 
    79  
    80438 
    81439    # If optional args is provided, use them as the block sizes. 
     
    94452 
    95453 
    96  
    97454  def copy(self): 
    98455    return CompressionShim(self.shim_context['shim_stack'].copy(), self.shim_context['optional_args']) 
    99  
    100456 
    101457   
     
    115471 
    116472 
    117  
    118473  # ========================================================= 
    119474  # Initializing the proper dicts during new socket creation. 
    120475  # ========================================================= 
    121476 
    122  
    123477  def tcpserversocket_getconnection(self, tcpserversocket): 
    124478    # Call the next layer of tcpserver_getconnection() 
    125479    (remote_ip, remote_port, repy_socket) = self.get_next_shim_layer().tcpserversocket_getconnection(tcpserversocket) 
    126480 
    127     if isinstance(repy_socket, ShimSocket): 
    128       sockobj = repy_socket._socket 
    129     else: 
    130       sockobj = repy_socket 
    131  
    132     # Initialize all the necessary dictionaries and sets. 
    133     self._send_buf_dict[repr(sockobj)] = '' 
    134     self._recv_buf_dict[repr(sockobj)] = '' 
    135     self._recv_exception_dict[repr(sockobj)] = None 
    136     self._result_buf_dict[repr(sockobj)] = '' 
    137     self._active_recv_socket_set.add(repr(sockobj)) 
    138  
    139     # FF: locks for synchronizing threads 
    140     self._full_lock[repr(sockobj)] = createlock() 
    141     self._empty_lock[repr(sockobj)] = createlock() 
    142     self._mutex_lock[repr(sockobj)] = createlock() 
    143  
    144  
    145     # FF: The "full" lock is initialized to locked, so that the sending thread   
    146     # will sleep until there is data to send (the lock will be released by the  
    147     # compression thread after it places data in the send buffer) 
    148     self._full_lock[repr(sockobj)].acquire(True) 
    149  
    150     # FF: Create the sending thread for this connection 
    151     def _send_wrapper(): 
    152       self._send_pending_data(sockobj) 
    153     createthread(_send_wrapper)   
    154  
    155     return (remote_ip, remote_port, ShimSocket(repy_socket, self)) 
    156  
    157  
    158  
     481    return (remote_ip, remote_port, CompressionSocket(repy_socket, self)) 
    159482 
    160483 
    161484  def openconnection(self, destip, destport, localip, localport, timeout): 
    162485    # Call the next layer of openconnection. 
    163     next_sockobj = self.get_next_shim_layer().openconnection(destip, destport, localip, localport, timeout) 
    164  
    165     if isinstance(next_sockobj, ShimSocket): 
    166       sockobj = next_sockobj._socket 
    167     else: 
    168       sockobj = next_sockobj 
    169  
    170     # Initialize all the necessary dictionaries and sets. 
    171     self._send_buf_dict[repr(sockobj)] = '' 
    172     self._recv_buf_dict[repr(sockobj)] = '' 
    173     self._recv_exception_dict[repr(sockobj)] = None 
    174     self._result_buf_dict[repr(sockobj)] = '' 
    175     self._active_recv_socket_set.add(repr(sockobj)) 
    176  
    177     self._full_lock[repr(sockobj)] = createlock() 
    178     self._empty_lock[repr(sockobj)] = createlock() 
    179     self._mutex_lock[repr(sockobj)] = createlock() 
    180  
    181     # FF: The "full" lock is initialized to locked, so that the sending thread   
    182     # will sleep until there is data to send (the lock will be released by the  
    183     # compression thread after it places data in the send buffer) 
    184     self._full_lock[repr(sockobj)].acquire(True) 
    185  
    186     # FF: Create the sending thread for this connection 
    187     def _send_wrapper(): 
    188       self._send_pending_data(sockobj) 
    189     createthread(_send_wrapper)   
    190  
    191     return ShimSocket(next_sockobj, self) 
    192  
    193  
    194  
    195  
    196  
     486    repy_socket = self.get_next_shim_layer().openconnection(destip, destport, localip, localport, timeout) 
     487 
     488    return CompressionSocket(repy_socket, self) 
    197489 
    198490 
     
    201493  # =================================================== 
    202494 
    203  
    204  
    205   def _send_pending_data(self,socket): 
    206     """ 
    207     FF: Sending thread. When there is data in the send buffer (signaled by the "full" 
    208     lock becoming available), it keeps trying to send until the buffer is empty,  
    209     at which point it signals to the compression thread by releasing the "empty" lock. 
    210     """ 
    211  
    212     while True: 
    213       self._full_lock[repr(socket)].acquire(True) 
    214       self._mutex_lock[repr(socket)].acquire(True) 
    215       send_buf = self._send_buf_dict[repr(socket)] 
    216       bytes_to_send = len(send_buf) 
    217        
    218       # Send everything in the send buffer and remove it from the buffer. 
    219       while send_buf: 
    220         sent_bytes = 0 
    221         try: 
    222           sent_bytes = self.get_next_shim_layer().socket_send(socket, send_buf) 
    223         except (SocketWouldBlockError, SocketClosedLocal, SocketClosedRemote), err: 
    224           pass 
    225         send_buf = send_buf[sent_bytes : ] 
    226  
    227       # When the send buffer is empty, release the locks and sleep until there 
    228       # is more data to send 
    229       self._mutex_lock[repr(socket)].release() 
    230       self._empty_lock[repr(socket)].release() 
    231  
    232  
    233   @_compression_shim_atomic_operation 
    234495  def socket_send(self, socket, msg): 
    235     """  
    236     <Purpose> 
    237       FF: Compression thread. Breaks up the original stream into blocks of  
    238       _SEND_BLOCK_SIZE and compresses individual blocks. Prepends the length to  
    239       each block as the header, adds a 'T' to the end of each block (so that I  
    240       don't have the change the recv part) and adds these blocks to the send 
    241       buffer sequentially if the send buffer is empty.  
    242  
    243     <Arguments> 
    244       Same as repy v2 socket API. 
    245  
    246     <Exceptions> 
    247       Same as repy v2 socket API. 
    248  
    249     <Side Effects> 
    250       Same as repy v2 socket API. 
    251  
    252     <Returns> 
    253       Same as repy v2 socket API. 
    254  
    255     """ 
    256     # How much of the original uncompressed data has been sent successfully. 
    257     total_original_bytes_sent = 0 
    258  
    259     # Keep sending the supplied message until no more data to send. 
    260     while msg: 
    261  
    262       # Create a compressed block of data out of the original message. 
    263       uncompressed_block = msg[0 : self._SEND_BLOCK_SIZE] 
    264       msg = msg[len(uncompressed_block) : ] 
    265       block_body = zlib.compress(uncompressed_block) 
    266  
    267       # Set the boolean tag as 'T'. 
    268       block_body += 'T' 
    269  
    270       # Append header information to indicate the length of the block. 
    271       block_header = str(len(block_body)) + ',' 
    272       block_data = block_header + block_body 
    273  
    274  
    275       # If the send buffer is empty, place this block in the send buffer 
    276       empty = self._empty_lock[repr(socket)].acquire(False) 
    277       mutex = self._mutex_lock[repr(socket)].acquire(False) 
    278       if mutex and empty: 
    279         self._send_buf_dict[repr(socket)] = block_data 
    280         self._mutex_lock[repr(socket)].release() 
    281         self._full_lock[repr(socket)].release() 
    282         total_original_bytes_sent += len(uncompressed_block) 
    283       else: 
    284         # Release any lock we may have successfully acquired: 
    285         if empty: 
    286           self._empty_lock[repr(socket)].release() 
    287         if mutex: 
    288           self._mutex_lock[repr(socket)].release() 
    289         break 
    290  
    291     # If we have not sent any data, the system's send buffer must be full. 
    292     if total_original_bytes_sent == 0 and len(msg) > 0: 
    293       raise SocketWouldBlockError 
    294  
    295     return total_original_bytes_sent 
    296          
    297        
    298  
    299  
    300  
    301  
    302   @_compression_shim_atomic_operation 
     496    return socket.send(msg) 
     497 
     498 
    303499  def socket_recv(self, socket, bytes): 
    304     """  
    305     <Purpose> 
    306       Receives as much as possible into the receive buffer until the socket 
    307       blocks. 
    308  
    309       Then, from the receive buffer, we reconstruct all the completely received 
    310       blocks. A complete block is a string in the form of "n,msg", where msg is 
    311       the compressed message and n is its length. 
    312  
    313       For each complete block, we decompress the message and add it to the 
    314       result buffer, which stores the original TCP stream. We return up to the 
    315       requested number of bytes from the result buffer. If the result buffer is 
    316       empty, we raise the blocking exception. 
    317  
    318     <Arguments> 
    319       Same as repy v2 socket API. 
    320  
    321     <Exceptions> 
    322       Same as repy v2 socket API. 
    323  
    324     <Side Effects> 
    325       Same as repy v2 socket API. 
    326  
    327     <Returns> 
    328       Same as repy v2 socket API. 
    329  
    330     """ 
    331  
    332     # Get the result buffer out of the dictionary. 
    333     result_buf = self._result_buf_dict[repr(socket)] 
    334  
    335     # If our buffer already has data, then we just return it. 
    336     if len(result_buf) > bytes: 
    337       requested_data = result_buf[0 : bytes] 
    338       self._result_buf_dict[repr(socket)] = result_buf[len(requested_data) : ] 
    339  
    340       return requested_data 
    341  
    342  
    343     # Receive as much as possible into the receive buffer, as long as the socket 
    344     # is active (i.e. not closed remotely or locally). 
    345     while repr(socket) in self._active_recv_socket_set: 
    346  
    347       try: 
    348         self._recv_buf_dict[repr(socket)] += self.get_next_shim_layer().socket_recv(socket, self._RECV_BLOCK_SIZE) 
    349  
    350         # If we have atleast lenght of bytes of compressed data, then most likely we 
    351         # already have bytes length uncompressed data. We do this so the application  
    352         # doesn't have to wait for CompressionShim to block before returning data. 
    353         if len(self._recv_buf_dict[repr(socket)]) > bytes: 
    354           break 
    355       # No more data to read from socket. 
    356       except SocketWouldBlockError: 
    357         break 
    358  
    359       # If a different exception occur, we save it first. We will raise it later 
    360       # when we run out of data to return (i.e. empty result buffer). The socket 
    361       # is now considered inactive. We remove it from the active socket set. 
    362       except (SocketClosedLocal, SocketClosedRemote), err: 
    363         self._recv_exception_dict[repr(socket)] = err 
    364         self._active_recv_socket_set.remove(repr(socket)) 
    365         break 
    366  
    367       # end-try 
    368  
    369     # end-while 
    370  
    371  
    372     # Reconstruct all the blocks of compressed messages from the raw TCP 
    373     # stream we received in the receive buffer. For each block, decompress 
    374     # it and add it to the result buffer. 
    375     while True: 
    376       compressed_block = self._reconstruct_blocks(socket) 
    377       if compressed_block is None: 
    378         break 
    379       elif len(compressed_block) > 0: 
    380         result_buf += zlib.decompress(compressed_block) 
    381  
    382     # If there is nothing in the result buffer, we have received all the data. 
    383     if result_buf == '': 
    384  
    385       # If we have saved exceptions, we raise them now, as these exceptions 
    386       # occurred at the end of the received stream. 
    387       if self._recv_exception_dict[repr(socket)]: 
    388         raise self._recv_exception_dict[repr(socket)] 
    389  
    390       # We simply run out of data without any other exceptions. 
    391       else: 
    392         raise SocketWouldBlockError 
    393  
    394     # Remove the portion requested by the application and return it. 
    395     requested_data = result_buf[0 : bytes] 
    396     self._result_buf_dict[repr(socket)] = result_buf[len(requested_data) : ] 
    397  
    398     return requested_data 
    399  
    400  
    401  
    402  
    403  
    404  
    405   def _reconstruct_blocks(self, socket): 
    406     """ 
    407     Helper method for the socket_recv method. Reconstructs and returns the 
    408     leftmost complete block starting from the head of the receive buffer. If 
    409     there are no more blocks to reconstruct, returns None. Returns an empty 
    410     string if the block is to be discarded. 
    411  
    412     For instance, we may have received the following stream into the 
    413     receive buffer (the '>' sign denotes the beginning of the receive buffer): 
    414      
    415     > 11,HelloWorldT6,pandaT7,googxxF10,micxxx 
    416  
    417     We start from the beginning and parse the header. We read 11 bytes. The last 
    418     byte is a 'T', so we know the block is complete. We have reconstructed 
    419     'HelloWorld' and we move on (because this method is probably used in a while 
    420     loop). Now the buffer becomes: 
    421  
    422     > 6,pandaT7,googxxF10,micxxx 
    423  
    424     Similarly, we are able to reconstruct the block 'panda'. The header 
    425     subsequently looks like: 
    426  
    427     > 7,googxxF10,micxxx 
    428  
    429     We can also reconstruct the new block as 'googxx', but since the tag is an 
    430     'F', we reject and discard the block. Now, the buffer becomes: 
    431  
    432     >10,micxxx 
    433  
    434     We expect that the block has 10 bytes of data, but the buffer ends before 
    435     that. We assume more data is coming in, so we leave the data on the buffer 
    436     and return all the good blocks we have read so far as 'HelloWorld' and 
    437     'panda'. 
    438  
    439     """ 
    440     # Get the receive buffer.  
    441     recv_buf = self._recv_buf_dict[repr(socket)] 
    442  
    443     # Base case: empty receive buffer. Return no blocks. 
    444     if recv_buf == '': 
    445       return None 
    446  
    447     # Stores the length of the block as a string. 
    448     block_length_str = '' 
    449  
    450     # Position in the receive buffer to be read. 
    451     cur_position = 0 
    452  
    453  
    454     # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #  
    455     # Parse the header 
    456     # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #  
    457  
    458     while True: 
    459  
    460       # Attempt to read the header of the block. Read a character at a 
    461       # time until comma or 'F'. 
    462       try: 
    463         char = recv_buf[cur_position] 
    464  
    465       # We still haven't obtained the complete header despite reaching the end 
    466       # of the buffer. Hopefully, we will receive more data into the buffer to 
    467       # form a complete header. For now, there is nothing we can do. We keep all 
    468       # the data in the buffer and exit. 
    469       except IndexError: 
    470         return None 
    471  
    472       cur_position += 1 
    473  
    474       # The correct header should end with a comma. Now that we have 
    475       # successfully read the header, we parse the length. 
    476       if char == ',': 
    477         block_length = int(block_length_str) 
    478         break 
    479  
    480       # The header has an 'F', so a SocketWouldBlockError must have occurred as 
    481       # the header is being sent (e.g. '3F'). We discard this block and retry 
    482       # from a position after the 'F' tag. 
    483       elif char == 'F': 
    484         self._recv_buf_dict[repr(socket)] = recv_buf[cur_position : ] 
    485         return '' 
    486  
    487       # The character is neither a comma or 'F', so we must be still reading the 
    488       # integers in the header. 
    489       elif char.isdigit(): 
    490         block_length_str += char 
    491  
    492       # There must have been a bug! 
    493       else: 
    494         err_str = 'CompressionShim: Invalid characer at position ' + str(cur_position)  
    495         err_str += ' in recv buffer: ' + str(recv_buf) 
    496         raise ShimInternalError(err_str) 
    497      
    498  
    499     # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #  
    500     # Reconstruct block 
    501     # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #  
    502      
    503     # It is possible that the we are still in the middle of receiving the block, 
    504     # and that we only have a part of the block in the receive buffer. 
    505     # Hopefully, we will receive more data into the buffer to form a complete 
    506     # block. For now, there is nothing we can do. We keep all the data in the 
    507     # buffer and exit. 
    508     if cur_position + block_length > len(recv_buf): 
    509       return None 
    510  
    511     # At this point, we know that we have a complete block. Read block_length of 
    512     # bytes from the receive buffer to form a block and remove it from the 
    513     # buffer. 
    514     block_body = recv_buf[cur_position : cur_position + block_length] 
    515     cur_position += block_length 
    516  
    517     # If it is a bad block, we remove it from the buffer, ignore it and retry 
    518     # after the 'F' tag. 
    519     if block_body.endswith('F'): 
    520       self._recv_buf_dict[repr(socket)] = recv_buf[cur_position : ] 
    521       return '' 
    522  
    523     # It's a good block! Remove it from the buffer, return it (minus the 'T') 
    524     # and keep reading the other blocks recursively. 
    525     elif block_body.endswith('T'): 
    526       self._recv_buf_dict[repr(socket)] = recv_buf[cur_position : ] 
    527       return block_body[0 : len(block_body) - 1] 
    528  
    529     # The block should not end with anything else! 
    530     else: 
    531       err_str = 'CompressionShim: Invalid block "' + block_body + '" in buffer: ' + recv_buf 
    532       raise ShimInternalError(err_str) 
    533  
    534  
    535  
    536  
    537  
    538  
    539  
    540  
    541  
    542  
    543  
    544  
    545  
    546  
     500    return socket.recv(bytes) 
     501 
     502 
     503  def socket_close(self, socket): 
     504    return socket.close() 
    547505 
    548506 
     
    579537    else: 
    580538      return 0 
    581  
    582  
    583  
    584  
    585539 
    586540 
     
    622576        raise SocketWouldBlockError("No UDP messages available now")   
    623577 
    624  
  • seattle/branches/repy_v2/shims/hidesizeshim.repy

    r6044 r6156  
    2323dy_import_module_symbols('random') 
    2424 
    25 _hidesize_shim_global_lock = createlock() 
    26  
    27 def _hidesize_shim_atomic_operation(target_func): 
    28   def wrapper(*args, **kwargs): 
     25 
     26 
     27 
     28def _hidesize_socket_atomic_operation(target_func): 
     29  def wrapper(self, *args, **kwargs): 
    2930    try: 
    30       _hidesize_shim_global_lock.acquire(True) 
    31       return target_func(*args, **kwargs) 
     31      self._atomic_lock.acquire(True) 
     32      return target_func(self, *args, **kwargs) 
    3233    finally: 
    33       _hidesize_shim_global_lock.release() 
     34      self._atomic_lock.release() 
    3435 
    3536  return wrapper 
     
    3738 
    3839 
    39 class HideSizeShim(BaseShim): 
    40  
    41   def __init__(self, shim_stack, optional_args=None): 
    42     """ 
    43     In the optional arguments, you can specify the maximum number of bytes to  
    44     append on each send call. For example, if the shim string is  
    45     '(HideSizeShim,1024)' then a random number of bytes up to 1024  
    46     will be appended to each chunk of data that is sent. 
    47  
    48     If it is not specified, then the default of 1024 will be used.  
    49  
    50     """ 
    51  
    52     # A dictionary that maps a socket to its send buffer, which stores any 
    53     # temporary data pending to be transmitted. 
    54     self._send_buf_dict = {} 
    55  
    56     # A dictionary that maps a socket to its receive buffer, which stores 
    57     # the 'bloated' TCP stream. 
    58     self._recv_buf_dict = {} 
    59  
    60     # A dictionary that maps a socket to any exception it raises while we try to 
    61     # receive as much as possible. 
    62     self._recv_exception_dict = {} 
    63  
    64     # A dictionary that maps a socket to the stream of unprocessed data. 
    65     self._result_buf_dict = {} 
    66  
    67     # A set that stores all the active sockets used in socket_send. If a socket 
    68     # raises an exception other than SocketWouldBlockError, it is considered 
    69     # inactive and thus removed from the set. 
    70     self._active_recv_socket_set = set() 
    71  
    72     # A dictionary that maps a socket to the locks used to synchronize the  
    73     # socket's sending thread and 'bloating' thread 
    74     self._full_lock = {} 
    75     self._empty_lock = {} 
    76     self._mutex_lock = {} 
    77  
    78  
    79     # If optional args is provided, use it to limit the number of bytes that 
    80     # will be appended. 
    81     if optional_args: 
    82       max_size = int(optional_args[0]) 
    83       self._MAX_BYTES = max_size 
    84     else: 
    85       self._MAX_BYTES = 1024 
    86  
    87     BaseShim.__init__(self, shim_stack, optional_args) 
    88  
    89  
    90  
    91   def copy(self): 
    92     return HideSizeShim(self.shim_context['shim_stack'].copy(), self.shim_context['optional_args']) 
    93  
    94  
    95    
    96   def get_advertisement_string(self): 
    97  
    98     optional_args = self.shim_context['optional_args'] 
    99     shim_name = '(HideSizeShim' 
    100  
    101     if optional_args: 
    102       shim_name += ',' + str(optional_args[0]) + ')' 
    103     else: 
    104       shim_name += ')' 
    105  
    106     return shim_name + self.get_next_shim_layer().get_advertisement_string() 
    107  
    108  
    109  
    110  
    111   def tcpserversocket_getconnection(self, tcpserversocket): 
    112     """ 
    113     Sets up dictionaries, sets, locks, threads, and other necessary things 
    114  
    115     """ 
    116  
    117     # Call the next layer of tcpserver_getconnection() 
    118     (remote_ip, remote_port, repy_socket) = self.get_next_shim_layer().tcpserversocket_getconnection(tcpserversocket) 
    119  
    120     if isinstance(repy_socket, ShimSocket): 
    121       sockobj = repy_socket._socket 
    122     else: 
    123       sockobj = repy_socket 
     40 
     41class HideSizeSocket(): 
     42 
     43  def __init__(self, socket, shim): 
     44    self.socket = socket 
     45    self.shim = shim 
    12446 
    12547    # Initialize all the necessary dictionaries and sets. 
    126     self._send_buf_dict[repr(sockobj)] = '' 
    127     self._recv_buf_dict[repr(sockobj)] = '' 
    128     self._recv_exception_dict[repr(sockobj)] = None 
    129     self._result_buf_dict[repr(sockobj)] = '' 
    130     self._active_recv_socket_set.add(repr(sockobj)) 
    131  
    132     # locks for synchronizing threads 
    133     self._full_lock[repr(sockobj)] = createlock() 
    134     self._empty_lock[repr(sockobj)] = createlock() 
    135     self._mutex_lock[repr(sockobj)] = createlock() 
    136  
    137  
    138     # The "full" lock is initialized to locked, so that the sending thread   
     48    self._send_buf = '' 
     49    self._recv_buf = '' 
     50    self._result_buf = '' 
     51 
     52    self._full_lock = createlock() 
     53    self._empty_lock = createlock() 
     54    self._mutex_lock = createlock() 
     55    self._atomic_lock = createlock() 
     56 
     57    self.closed_local = False 
     58    self.closed_remote = False 
     59    self.close_lock = createlock() 
     60 
     61    # FF: The "full" lock is initialized to locked, so that the sending thread   
    13962    # will sleep until there is data to send (the lock will be released by the  
    140     # bloating thread after it places data in the send buffer) 
    141     self._full_lock[repr(sockobj)].acquire(True) 
    142  
    143     # Create the sending thread for this connection 
    144     def _send_wrapper(): 
    145       self._send_pending_data(sockobj) 
    146     createthread(_send_wrapper)   
    147  
    148     return (remote_ip, remote_port, ShimSocket(repy_socket, self)) 
    149  
    150  
    151  
    152   def openconnection(self, destip, destport, localip, localport, timeout): 
    153     """ 
    154     Sets up dictionaries, sets, locks, threads, and other necessary things 
    155  
    156     """ 
    157  
    158     # Call the next layer of openconnection. 
    159     next_sockobj = self.get_next_shim_layer().openconnection(destip, destport, localip, localport, timeout) 
    160  
    161     if isinstance(next_sockobj, ShimSocket): 
    162       sockobj = next_sockobj._socket 
    163     else: 
    164       sockobj = next_sockobj 
    165  
    166     # Initialize all the necessary dictionaries and sets. 
    167     self._send_buf_dict[repr(sockobj)] = '' 
    168     self._recv_buf_dict[repr(sockobj)] = '' 
    169     self._recv_exception_dict[repr(sockobj)] = None 
    170     self._result_buf_dict[repr(sockobj)] = '' 
    171     self._active_recv_socket_set.add(repr(sockobj)) 
    172  
    173     self._full_lock[repr(sockobj)] = createlock() 
    174     self._empty_lock[repr(sockobj)] = createlock() 
    175     self._mutex_lock[repr(sockobj)] = createlock() 
    176  
    177     # The "full" lock is initialized to locked, so that the sending thread   
    178     # will sleep until there is data to send (the lock will be released by the  
    179     # bloating thread after it places data in the send buffer) 
    180     self._full_lock[repr(sockobj)].acquire(True) 
    181  
    182     # Create the sending thread for this connection 
    183     def _send_wrapper(): 
    184       self._send_pending_data(sockobj) 
    185     createthread(_send_wrapper)   
    186  
    187     return ShimSocket(next_sockobj, self) 
    188  
    189  
    190  
    191   def _send_pending_data(self,socket): 
    192     """ 
    193     Sending thread. When there is data in the send buffer (signaled by the "full" 
    194     lock becoming available), it keeps trying to send until the buffer is empty,  
    195     at which point it signals to the bloating thread by releasing the "empty" lock. 
    196     """ 
    197  
    198     while True: 
    199       self._full_lock[repr(socket)].acquire(True) 
    200       self._mutex_lock[repr(socket)].acquire(True) 
    201       send_buf = self._send_buf_dict[repr(socket)] 
    202       bytes_to_send = len(send_buf) 
    203        
    204       # Send everything in the send buffer and remove it from the buffer. 
    205       while send_buf: 
    206         sent_bytes = 0 
    207         try: 
    208           sent_bytes = self.get_next_shim_layer().socket_send(socket, send_buf) 
    209         except (SocketWouldBlockError, SocketClosedLocal, SocketClosedRemote), err: 
    210           # ToDo: we shouldn't just pass 
    211           pass 
    212         send_buf = send_buf[sent_bytes : ] 
    213  
    214       # When the send buffer is empty, release the locks and sleep until there 
    215       # is more data to send 
    216       self._mutex_lock[repr(socket)].release() 
    217       self._empty_lock[repr(socket)].release() 
    218  
    219  
    220   def socket_send(self, socket, msg): 
     63    # compression thread after it places data in the send buffer) 
     64    self._full_lock.acquire(True) 
     65 
     66    # FF: Create the sending thread for this connection 
     67    createthread(self._send_pending_data) 
     68 
     69 
     70  def send(self, msg): 
    22171    """  
    22272      Bloating thread. Picks a random int and gets that many random bytes to  
     
    23989      msg_length = len(msg) 
    24090      # Don't send more "junk" data than the original message data 
    241       extradata_length = min(random_int_below(self._MAX_BYTES),random_int_below(len(msg))) 
     91      extradata_length = min(random_int_below(self.shim._MAX_BYTES),random_int_below(len(msg))) 
    24292      extradata = ''.join(random_sample(list(msg),extradata_length)) 
    24393 
     
    251101 
    252102      # If the send buffer is empty, place this block in the send buffer 
    253       empty = self._empty_lock[repr(socket)].acquire(False) 
    254       mutex = self._mutex_lock[repr(socket)].acquire(False) 
     103      empty = self._empty_lock.acquire(False) 
     104      mutex = self._mutex_lock.acquire(False) 
    255105      if mutex and empty: 
    256         self._send_buf_dict[repr(socket)] = block_data 
    257         self._mutex_lock[repr(socket)].release() 
    258         self._full_lock[repr(socket)].release() 
     106        self._send_buf = block_data 
     107        self._mutex_lock.release() 
     108        self._full_lock.release() 
    259109        total_original_bytes_sent += len(msg) 
    260110      else: 
    261111        # Release any lock we may have successfully acquired: 
    262112        if empty: 
    263           self._empty_lock[repr(socket)].release() 
     113          self._empty_lock.release() 
    264114        if mutex: 
    265           self._mutex_lock[repr(socket)].release() 
     115          self._mutex_lock.release() 
    266116        break 
    267117 
     
    277127 
    278128 
    279   @_hidesize_shim_atomic_operation 
    280   def socket_recv(self, socket, bytes): 
     129  @_hidesize_socket_atomic_operation 
     130  def recv(self, bytes): 
    281131    """  
    282132      Receive into the receive buffer until the socket blocks. 
     
    288138 
    289139    # Get the result buffer out of the dictionary. 
    290     result_buf = self._result_buf_dict[repr(socket)] 
     140    result_buf = self._result_buf 
    291141 
    292142    # If result buffer already has enough data in it, then we just return it. 
    293143    if len(result_buf) > bytes: 
    294144      requested_data = result_buf[0 : bytes] 
    295       self._result_buf_dict[repr(socket)] = result_buf[len(requested_data) : ] 
     145      self._result_buf = result_buf[len(requested_data) : ] 
    296146      return requested_data 
    297147 
    298148    # Otherwise, we need to get fresh data off the wire and process it: 
    299149 
     150    recv_exception = None 
     151 
    300152    # Receive as much as possible into the receive buffer 
    301     while repr(socket) in self._active_recv_socket_set: 
     153    while True: 
    302154      try: 
    303         self._recv_buf_dict[repr(socket)] += self.get_next_shim_layer().socket_recv(socket, 4096) 
     155        self._recv_buf += self.socket.recv(4096) 
    304156 
    305157      # Stop trying to receive when there is no more data to read from the socket 
     
    311163      # is now considered inactive. We remove it from the active socket set. 
    312164      except (SocketClosedLocal, SocketClosedRemote), err: 
    313         self._recv_exception_dict[repr(socket)] = err 
    314         self._active_recv_socket_set.remove(repr(socket)) 
     165        self.closed_remote = True 
     166        recv_exception = err 
    315167        break 
    316  
    317168 
    318169 
     
    320171    # receive buffer, and add it to the result buffer. 
    321172    while True: 
    322       bloated_block = self._reconstruct_blocks(socket) 
     173      bloated_block = self._reconstruct_blocks() 
    323174      if bloated_block is None: 
    324175        break 
     
    331182      # If we have saved exceptions, we raise them now, as these exceptions 
    332183      # occurred at the end of the received stream. 
    333       if self._recv_exception_dict[repr(socket)]: 
    334         raise self._recv_exception_dict[repr(socket)] 
     184      if recv_exception: 
     185        raise recv_exception 
    335186 
    336187      # If we run out of data without any other exceptions. 
     
    341192    # and remove from result buffer 
    342193    requested_data = result_buf[0 : bytes] 
    343     self._result_buf_dict[repr(socket)] = result_buf[len(requested_data) : ] 
     194    self._result_buf = result_buf[len(requested_data) : ] 
    344195 
    345196    return requested_data 
    346197 
    347198 
    348  
    349   def _reconstruct_blocks(self, socket): 
     199  def close(self): 
     200    self.close_lock.acquire(True) 
     201    result = not self.closed_local 
     202    self.closed_local = True 
     203    self.close_lock.release() 
     204 
     205    try: 
     206      self._full_lock.release() 
     207    except Exception: 
     208      pass 
     209 
     210    return result 
     211 
     212 
     213 
     214  def _reconstruct_blocks(self): 
    350215    """ 
    351216    Helper method for the socket_recv method. Reconstructs and returns the 
     
    355220    """ 
    356221    # Get the receive buffer.  
    357     recv_buf = self._recv_buf_dict[repr(socket)] 
     222    recv_buf = self._recv_buf 
    358223 
    359224    # Return no blocks if the buffer is empty 
     
    413278    cur_position += msg_length 
    414279    cur_position += extradata_length 
    415     self._recv_buf_dict[repr(socket)] = recv_buf[cur_position : ] 
     280    self._recv_buf = recv_buf[cur_position : ] 
    416281    return block_body[0 : msg_length ] 
    417282 
    418283 
    419  
     284  def _send_pending_data(self): 
     285    """ 
     286    FF: Sending thread. When there is data in the send buffer (signaled by the "full" 
     287    lock becoming available), it keeps trying to send until the buffer is empty,  
     288    at which point it signals to the compression thread by releasing the "empty" lock. 
     289    """ 
     290 
     291    while True: 
     292      self._full_lock.acquire(True) 
     293      self._mutex_lock.acquire(True) 
     294 
     295      try: 
     296        send_buf = self._send_buf 
     297        bytes_to_send = len(send_buf) 
     298         
     299        # Send everything in the send buffer and remove it from the buffer. 
     300        while send_buf: 
     301          sent_bytes = 0 
     302          try: 
     303            sent_bytes = self.socket.send(send_buf) 
     304          except SocketWouldBlockError, err: 
     305            pass 
     306          except (SocketClosedLocal, SocketClosedRemote), err: 
     307            self.closed_remote = True 
     308            return 
     309          send_buf = send_buf[sent_bytes : ] 
     310 
     311        if self.closed_local: 
     312          self.socket.close() 
     313          return 
     314 
     315      finally: 
     316        # When the send buffer is empty, release the locks and sleep until there 
     317        # is more data to send 
     318        self._mutex_lock.release() 
     319        try: 
     320          self._empty_lock.release() 
     321        except Exception: 
     322          pass 
     323 
     324 
     325 
     326 
     327class HideSizeShim(BaseShim): 
     328 
     329  def __init__(self, shim_stack, optional_args=None): 
     330    """ 
     331    In the optional arguments, you can specify the maximum number of bytes to  
     332    append on each send call. For example, if the shim string is  
     333    '(HideSizeShim,1024)' then a random number of bytes up to 1024  
     334    will be appended to each chunk of data that is sent. 
     335 
     336    If it is not specified, then the default of 1024 will be used.  
     337 
     338    """ 
     339 
     340    # If optional args is provided, use it to limit the number of bytes that 
     341    # will be appended. 
     342    if optional_args: 
     343      max_size = int(optional_args[0]) 
     344      self._MAX_BYTES = max_size 
     345    else: 
     346      self._MAX_BYTES = 1024 
     347 
     348    BaseShim.__init__(self, shim_stack, optional_args) 
     349 
     350 
     351 
     352  def copy(self): 
     353    return HideSizeShim(self.shim_context['shim_stack'].copy(), self.shim_context['optional_args']) 
     354 
     355 
     356   
     357  def get_advertisement_string(self): 
     358 
     359    optional_args = self.shim_context['optional_args'] 
     360    shim_name = '(HideSizeShim' 
     361 
     362    if optional_args: 
     363      shim_name += ',' + str(optional_args[0]) + ')' 
     364    else: 
     365      shim_name += ')' 
     366 
     367    return shim_name + self.get_next_shim_layer().get_advertisement_string() 
     368 
     369 
     370 
     371 
     372  def tcpserversocket_getconnection(self, tcpserversocket): 
     373    """ 
     374    Sets up dictionaries, sets, locks, threads, and other necessary things 
     375 
     376    """ 
     377 
     378    # Call the next layer of tcpserver_getconnection() 
     379    (remote_ip, remote_port, repy_socket) = self.get_next_shim_layer().tcpserversocket_getconnection(tcpserversocket) 
     380 
     381    return (remote_ip, remote_port, HideSizeSocket(repy_socket, self)) 
     382 
     383 
     384 
     385  def openconnection(self, destip, destport, localip, localport, timeout): 
     386    """ 
     387    Sets up dictionaries, sets, locks, threads, and other necessary things 
     388 
     389    """ 
     390 
     391    # Call the next layer of openconnection. 
     392    repy_socket = self.get_next_shim_layer().openconnection(destip, destport, localip, localport, timeout) 
     393 
     394    return HideSizeSocket(repy_socket, self) 
     395 
     396 
     397  def socket_send(self, socket, msg): 
     398    return socket.send(msg) 
     399 
     400 
     401  def socket_recv(self, socket, bytes): 
     402    return socket.recv(bytes) 
     403 
     404 
     405  def socket_close(self, socket): 
     406    return socket.close() 
     407 
  • seattle/branches/repy_v2/shims/mobile_socket.repy

    r5051 r6156  
    1515    logstr = 't = %.2f ' % getruntime() + logstr 
    1616    f = openfile('mobility.log', True) 
    17     f.writeat(logstr, _debug_log_context['written']) 
    18     _debug_log_context['written'] += len(logstr) 
     17    f.writeat(logstr, _debug_log_context['bytes_written']) 
     18    _debug_log_context['bytes_written'] += len(logstr) 
    1919    f.close() 
    2020  _debug_log_context['lock'].release() 
  • seattle/branches/repy_v2/shims/mobilityshim.repy

    r5051 r6156  
    526526            return MobileSocket(shim_sock, [destip, destport, localip, localport, original_timeout], connection_id) 
    527527 
    528           # The remote server socket is listening for reconnections, but 
    529           # is closed, so we shouldn't be able to connect. 
    530           elif message == "D": 
     528          else: 
     529            # The remote server socket is listening for reconnections, but 
     530            # is closed, so we shouldn't be able to connect. 
     531            if message != "D": 
     532              debug_log('zzzz Openconn received bad connection id response: ', message, '\n', show=True) 
    531533            try: 
    532534              shim_sock.close() 
     
    535537            raise ConnectionRefusedError("The connection was refused!") 
    536538 
    537           else: 
    538             debug_log('zzzz Openconn received bad connection id response: ', message, '\n', show=True) 
    539  
    540539        try: 
    541540          shim_sock.close() 
  • seattle/branches/repy_v2/shims/msg_chunk_lib.repy

    r5051 r6156  
    695695                pass 
    696696 
    697           else: 
    698             for socket in self._complete_socket_set.copy(): 
    699               try: 
    700                 socket.close() 
    701               except: 
    702                 pass 
    703  
    704             self._active_socket_set.clear() 
     697          for socket in self._complete_socket_set.copy(): 
     698            try: 
     699              socket.close() 
     700            except: 
     701              pass 
     702 
     703          self._active_socket_set.clear() 
    705704 
    706705      finally: 
  • seattle/branches/repy_v2/shims/multipathshim.repy

    r4961 r6156  
    361361            continue 
    362362 
     363          serversocket = self.tcpsocket_dict[cur_tcp_socket]['tcpsockobj'] 
     364 
     365          if not serversocket: 
     366            continue 
     367 
    363368          try: 
    364369            # Try to connect to this tcp socket 
    365             (remoteip, remoteport, sockobj) = self.tcpsocket_dict[cur_tcp_socket]['tcpsockobj'].getconnection() 
     370            (remoteip, remoteport, sockobj) = serversocket.getconnection() 
    366371          except (SocketWouldBlockError, ResourceExhaustedError): 
    367372            # If there is any exception we don't raise it because we have already made 
     
    542547        sockobj = self.shim_stack_dict[cur_shim_stack].peek().openconnection(destip, destport,  
    543548                                                                             localip, localport, timeout) 
    544       except (CleanupInProgressError, TimeoutError, InternetConnectivityError, ConnectionRefusedError), err: 
     549      except (NetworkError, TimeoutError), err: 
    545550        exception_list.append(err) 
    546551        pass 
     
    592597       
    593598      while not all_connected: 
    594         all_connected = True       
     599        all_connected = True 
     600 
     601        for cur_shim_stack in connection_dict.keys(): 
     602          try: 
     603            # Check to see if we are already connected. 
     604            if connection_dict[cur_shim_stack]['sockobj']: 
     605              continue 
     606            try: 
     607              sockobj = self.shim_stack_dict[cur_shim_stack].peek().openconnection(self._destip, 
     608                            connection_dict[cur_shim_stack]['destport'], self._localip, 
     609                            connection_dict[cur_shim_stack]['localport'], self._timeout) 
    595610         
    596         for cur_shim_stack in self.shim_stack_dict.keys(): 
    597           # Check to see if we are already connected. 
    598           if connection_dict[cur_shim_stack]['sockobj']: 
    599             continue 
    600           try: 
    601             sockobj = self.shim_stack_dict[cur_shim_stack].peek().openconnection(self._destip, 
    602                           connection_dict[cur_shim_stack]['destport'], self._localip, 
    603                           connection_dict[cur_shim_stack]['localport'], self._timeout) 
    604        
    605           except (CleanupInProgressError, TimeoutError, InternetConnectivityError, ConnectionRefusedError), err: 
    606             # If there is any exception we don't raise it because we have already made 
    607             # one successful connection. 
    608             # If we have gotten an exception then there is atleast one shimstack that 
    609             # we haven't been able to connect to. 
    610             all_connected = False 
    611           else: 
    612             connection_dict[cur_shim_stack]['sockobj'] = sockobj 
    613              
     611            except (NetworkError, TimeoutError), err: 
     612              # If there is any exception we don't raise it because we have already made 
     613              # one successful connection. 
     614              # If we have gotten an exception then there is atleast one shimstack that 
     615              # we haven't been able to connect to. 
     616              all_connected = False 
     617            else: 
     618              connection_dict[cur_shim_stack]['sockobj'] = sockobj 
     619 
     620          except KeyError: 
     621            pass 
    614622 
    615623    return _openconn_thread_helper 
     
    654662 
    655663    for cur_shim_stack in shim_stack_keys: 
     664      if cur_shim_stack not in tcpsocket_dict.keys(): 
     665        tcpsocket_dict[cur_shim_stack] = {} 
     666 
     667      tcpsocket_dict[cur_shim_stack]['localport'] = localport 
     668      tcpsocket_dict[cur_shim_stack]['tcpsockobj'] = None 
     669 
    656670      try: 
    657671        # Call listenforconnection on the top shim of this particular shimstack. 
     
    661675        pass 
    662676      else: 
    663         if cur_shim_stack not in tcpsocket_dict.keys(): 
    664           tcpsocket_dict[cur_shim_stack] = {} 
    665  
    666         tcpsocket_dict[cur_shim_stack]['localport'] = localport 
    667         tcpsocket_dict[cur_shim_stack]['tcpsockobj'] = None 
    668  
    669677        # Once connection has been made, add the tcpsocket object to connection_dict 
    670678        tcpsocket_dict[cur_shim_stack]['tcpsockobj'] = tcpsockobj 
     
    708716        all_connected = True   
    709717 
    710         for cur_shim_stack in self.shim_stack_dict.keys(): 
    711           if cur_shim_stack not in tcpsocket_dict.keys(): 
    712             tcpsocket_dict[cur_shim_stack] = {} 
    713             tcpsocket_dict[cur_shim_stack]['tcpsockobj'] = None 
    714             tcpsocket_dict[cur_shim_stack]['localport'] = 0 
    715  
    716           # Check to see if we are already connected. 
    717           if tcpsocket_dict[cur_shim_stack]['tcpsockobj']: 
    718             continue 
    719  
     718        for cur_shim_stack in tcpsocket_dict.keys(): 
    720719          try: 
    721             tcpsockobj = self.shim_stack_dict[cur_shim_stack].peek().listenforconnection(self._localip, 
    722                              tcpsocket_dict[cur_shim_stack]['localport']) 
    723  
    724           except (AlreadyListeningError, DuplicateTupleError): 
    725             # If there is any exception we don't raise it because we have already made 
    726             # one successful connection. 
    727             # If we have gotten an exception then there is atleast one shimstack that 
    728             # we haven't been able to connect to. 
    729             all_connected = False 
    730           else: 
    731             tcpsocket_dict[cur_shim_stack]['tcpsockobj'] = tcpsockobj 
    732  
    733  
    734     return _listenforconnection_helper  
     720            # Check to see if we are already connected. 
     721            if tcpsocket_dict[cur_shim_stack]['tcpsockobj']: 
     722              continue 
     723 
     724            try: 
     725              tcpsockobj = self.shim_stack_dict[cur_shim_stack].peek().listenforconnection(self._localip, 
     726                               tcpsocket_dict[cur_shim_stack]['localport']) 
     727 
     728            except (AlreadyListeningError, DuplicateTupleError): 
     729              # If there is any exception we don't raise it because we have already made 
     730              # one successful connection. 
     731              # If we have gotten an exception then there is atleast one shimstack that 
     732              # we haven't been able to connect to. 
     733              all_connected = False 
     734            else: 
     735              tcpsocket_dict[cur_shim_stack]['tcpsockobj'] = tcpsockobj 
     736 
     737          except KeyError: 
     738            pass 
     739 
     740    return _listenforconnection_helper 
    735741 
    736742 
  • seattle/branches/repy_v2/shims/natpunchshim.repy

    r4954 r6156  
    102102      try: 
    103103        forwarder_ip, forwarder_port = optional_args[0].split(':') 
    104         self.default_forwarder = (forwarder_ip, int(forwarder_port)) 
     104        self.default_forwarder = [forwarder_ip, int(forwarder_port)] 
    105105      except ValueError: 
    106106        raise ShimInternalError("Optional arg provided is not of valid format. Must be IP:port.") 
     
    108108    # If no default forwarder is provided. 
    109109    else: 
    110       self.default_forwarder = None 
     110      self.default_forwarder = [] 
    111111 
    112112    BaseShim.__init__(self, shim_stack, optional_args) 
     
    416416            if _NAT_SHIM_DEBUG_MODE: 
    417417              log("Setting the default forwarder\n") 
    418             self.default_forwarder = (forwarder_ip, forwarder_port) 
     418            self.default_forwarder.append(forwarder_ip) 
     419            self.default_forwarder.append(forwarder_port) 
    419420 
    420421          if _NAT_SHIM_DEBUG_MODE: 
     
    464465    optional_args_copy = self.shim_context['optional_args'] 
    465466 
    466     return NatPunchShim(shim_stack_copy, optional_args_copy) 
     467    shim_copy = NatPunchShim(shim_stack_copy, optional_args_copy) 
     468    shim_copy.default_forwarder = self.default_forwarder 
     469 
     470    return shim_copy 
    467471 
    468472 
  • seattle/branches/repy_v2/shims/rsashim.repy

    r4691 r6156  
    3333RSA_CHUNK_SIZE = 2**8 
    3434 
    35 class RSAShim(BaseShim): 
    36  
    37   def __init__(self, shim_stack = ShimStack(), optional_args = None): 
    38     """ 
    39     <Purpose> 
    40       Initialize the RSAShim. 
    41  
    42     <Arguments> 
    43       shim_stack - the shim stack underneath us. 
    44  
    45       optional_args - The optional args (if provided) will be used to 
    46         encrypt and decrypt data. A new key will not be generated. Note 
    47         that if optional_args is provided then it must be provided for  
    48         both the server side and client side, otherwise they won't be 
    49         able to communicate. 
    50  
    51     <Side Effects> 
    52       None 
    53  
    54     <Exceptions> 
    55       ShimInternalError raised if the optional args provided is not of 
    56       the proper format. 
    57  
    58     <Return> 
    59       None 
    60     """ 
    61  
    62     # We keep a dictionary of key dicts for each socket that has 
    63     # a connection. 
    64     self.rsa_key_dict = {} 
    65     self.rsa_buffer_context = {} 
    66     self.active_sock_set = set() 
    67     self.socket_closed_local = [] 
    68      
    69     BaseShim.__init__(self, shim_stack, optional_args) 
    70  
    71   
    72  
    73  
    74   # ======================================== 
    75   # TCP section of RSA Shim 
    76   # ======================================== 
    77   def openconnection(self, destip, destport, localip, localport, timeout): 
    78     """ 
    79     <Purpose> 
    80       Create a connection and initiate the handshake with the server. 
    81       Make sure that we have the same method of communication. 
    82  
    83     <Arguments> 
    84       Same arguments as Repy V2 Api for openconnection. 
    85  
    86     <Side Effects> 
    87       Some messages are sent back and forth. 
    88  
    89     <Exceptions> 
    90       Same exceptions as Repy V2 Api for openconnection. Note that 
    91       a ConnectionRefusedError is raised if the handhake fails with 
    92       the server. 
    93  
    94     <Return> 
    95       A socket like object. 
    96     """ 
    97  
    98     # Open a connection by calling the next layer of shim. 
    99     next_sockobj = self.get_next_shim_layer().openconnection(destip, destport, localip, localport, timeout) 
    100  
    101     if isinstance(next_sockobj, ShimSocket): 
    102       sockobj = next_sockobj._socket 
    103     else: 
    104       sockobj = next_sockobj 
    105  
    106     # Generate a new set of pubkey/privkey and send the pub 
    107     # key back to the server to receive the actual key. 
    108     (temp_pub, temp_priv) = rsa_gen_pubpriv_keys(BITSIZE) 
    109      
    110     # Greet the server and send it a temporary pubkey and wait 
    111     # for a response. 
    112     session_sendmessage(sockobj, GREET_TAG + str(temp_pub)) 
    113     encrypted_response = session_recvmessage(sockobj) 
    114  
    115     response = rsa_decrypt(encrypted_response, temp_priv) 
    116      
    117     if response.startswith(NEW_SHARE_TAG): 
    118       key_pair = response.split(NEW_SHARE_TAG)[1] 
    119  
    120       (pub_key, priv_key) = key_pair.split(':::') 
    121  
    122       # Add the socket to the active list. 
    123       self.active_sock_set.add(repr(sockobj)) 
    124  
    125       self.rsa_key_dict[repr(sockobj)] = {} 
    126       self.rsa_key_dict[repr(sockobj)]['pub_key'] = eval(pub_key) 
    127       self.rsa_key_dict[repr(sockobj)]['priv_key'] = eval(priv_key) 
    128  
    129       self.rsa_buffer_context[repr(sockobj)] = {'send_buffer' : '', 
    130                                                 'recv_encrypted_buffer' : '', 
    131                                                 'recv_buffer' : '', 
    132                                                 'send_lock' : createlock(), 
    133                                                 'recv_lock' : createlock(), 
    134                                                 'recv_encrypt_lock' : createlock()} 
    135     else: 
    136       raise ConnectionRefusedError("Unable to complete handshake with server and " + 
    137                                    "agree on RSA key.") 
    138  
    139     # Start up the two threads that does the sending and receiving  
    140     # of the data. 
    141     createthread(self._sending_thread(sockobj)) 
    142     createthread(self._receiving_thread(sockobj)) 
    143  
    144     return ShimSocket(next_sockobj, self) 
    145  
    146  
    147  
    148  
    149   def tcpserversocket_getconnection(self, tcpserversocket): 
    150     """ 
    151     <Purpose> 
    152       Accept a connection from the client. Complete a handshake 
    153       to make sure that both the server and client have the same 
    154       pub/priv key. 
    155     
    156     <Arguments> 
    157       Same arguments as Repy V2 Api for tcpserver.getconnection() 
    158  
    159     <Side Effects> 
    160       Some messages are sent back and forth. Some RSA keys are generated 
    161       so things might slow down. 
    162  
    163     <Return> 
    164       A tuple of remoteip, remoteport and socket like object. 
    165     """ 
    166  
    167     # Call the next layer of socket to get a connection. 
    168     (remoteip, remoteport, next_sockobj) = self.get_next_shim_layer().tcpserversocket_getconnection(tcpserversocket) 
    169  
    170     # We want the actual socket object, not the shimsocket 
    171     if isinstance(next_sockobj, ShimSocket): 
    172       sockobj = next_sockobj._socket 
    173     else: 
    174       sockobj = next_sockobj 
    175  
    176     # Try to get the initial greeting from the connection. 
    177     try: 
    178       initial_msg = session_recvmessage(sockobj) 
    179     except (ValueError, SessionEOF): 
    180       raise SocketWouldBlockError("No connection available right now.") 
    181  
    182      
    183     # If we get a greeting tag then we send back to the client a new set 
    184     # of key that will be used to do all the communication. 
    185     if initial_msg.startswith(GREET_TAG): 
    186       # Extract the pubkey and convert it to dict. 
    187       client_pubkey = eval(initial_msg.split(GREET_TAG)[1]) 
    188  
    189       # Generate new key. 
    190       (pub_key, priv_key) = rsa_gen_pubpriv_keys(BITSIZE) 
    191  
    192       # Add the socket to the active socket set. 
    193       self.active_sock_set.add(repr(sockobj)) 
    194  
    195       self.rsa_key_dict[repr(sockobj)] = {} 
    196       self.rsa_key_dict[repr(sockobj)]['pub_key'] = pub_key 
    197       self.rsa_key_dict[repr(sockobj)]['priv_key'] = priv_key 
    198  
    199       self.rsa_buffer_context[repr(sockobj)] = {'send_buffer' : '', 
    200                                                 'recv_encrypted_buffer' : '', 
    201                                                 'recv_buffer' : '', 
    202                                                 'send_lock' : createlock(), 
    203                                                 'recv_lock' : createlock(), 
    204                                                 'recv_encrypt_lock' : createlock()} 
    205  
    206       # Send back the new set of keys, encrypted with the pubkey 
    207       # provided by the client initially. 
    208       new_msg = NEW_SHARE_TAG + str(pub_key) + ':::' + str(priv_key) 
    209       session_sendmessage(sockobj, rsa_encrypt(new_msg, client_pubkey)) 
    210  
    211     else: 
    212       raise ConnectionRefusedError("Unable to complete handshake with server and " + 
    213                                    "agree on RSA key.") 
    214  
    215  
    216     # Start up the two threads that does the sending and receiving 
    217     # of the data. 
    218     createthread(self._sending_thread(sockobj)) 
    219     createthread(self._receiving_thread(sockobj)) 
    220  
    221     return (remoteip, remoteport, ShimSocket(next_sockobj, self)) 
    222  
    223  
    224  
    225  
    226  
    227   def socket_send(self, socket, msg): 
     35 
     36class RSASocket(): 
     37 
     38  def __init__(self, socket, rsa_key_dict, rsa_buffer_context): 
     39    self.socket = socket 
     40    self.rsa_key_dict = rsa_key_dict 
     41    self.rsa_buffer_context = rsa_buffer_context 
     42 
     43    self.socket_closed_local = False 
     44    self.send_closed_remote = False 
     45    self.recv_closed_remote = False 
     46    self.decrypt_done = False 
     47    self.close_lock = createlock() 
     48 
     49    createthread(self._sending_thread) 
     50    createthread(self._receiving_thread) 
     51 
     52 
     53  def send(self, msg): 
    22854    """ 
    22955    <Purpose> 
     
    24571    # If the send buffer is not empty, then we raise a  
    24672    # SocketWouldBlockError. 
    247     if self.rsa_buffer_context[repr(socket)]['send_buffer']: 
     73    if self.rsa_buffer_context['send_buffer']: 
    24874      raise SocketWouldBlockError("Send buffer is full") 
    24975 
    250     if repr(socket) in self.socket_closed_local: 
     76    if self.socket_closed_local: 
    25177      raise SocketClosedLocal("Socket closed locally!") 
    25278 
    253     if repr(socket) not in self.active_sock_set: 
     79    if self.send_closed_remote: 
    25480      raise SocketClosedRemote("Socket closed remotely!") 
    25581 
    25682    # Keep track of the original msg size, as it will change. 
    25783    original_data_size = len(msg) 
    258  
    25984 
    26085    # Encrypt the data and put a little buffer. 
     
    27297 
    27398      s=getruntime() 
    274       encrypt_msg = rsa_encrypt(sub_msg, self.rsa_key_dict[repr(socket)]['pub_key']) 
     99      encrypt_msg = rsa_encrypt(sub_msg, self.rsa_key_dict['pub_key']) 
    275100      e=getruntime() - s 
    276101      header = str(len(encrypt_msg)) + '\n' 
     
    278103      # Now that we are done encrypting everything, we add it to the send buffer 
    279104      # in order for it to be sent across the network. 
    280       self.rsa_buffer_context[repr(socket)]['send_lock'].acquire(True) 
    281       try: 
    282         self.rsa_buffer_context[repr(socket)]['send_buffer'] += header + encrypt_msg 
     105      self.rsa_buffer_context['send_lock'].acquire(True) 
     106      self.close_lock.acquire(True) 
     107      try: 
     108        if not self.socket_closed_local: 
     109          self.rsa_buffer_context['send_buffer'] += header + encrypt_msg 
     110        else: 
     111          raise SocketClosedLocal("Socket closed locally!") 
    283112      finally: 
    284         self.rsa_buffer_context[repr(socket)]['send_lock'].release() 
    285  
    286  
     113        self.rsa_buffer_context['send_lock'].release() 
     114        self.close_lock.release() 
    287115 
    288116    return original_data_size 
    289117 
    290118 
    291  
    292  
    293  
    294   def socket_recv(self, socket, bytes): 
     119  def recv(self, bytes): 
    295120    """ 
    296121    <Purpose> 
     
    300125 
    301126    # Check first if the socket_close call was called. 
    302     if repr(socket) in self.socket_closed_local: 
     127    if self.socket_closed_local: 
    303128      raise SocketClosedLocal("Socket closed locally!") 
    304129 
     
    312137    # been decrypted and put in the buffer yet. 
    313138 
    314  
    315     if not self.rsa_buffer_context[repr(socket)]['recv_buffer']: 
    316       if repr(socket) not in self.active_sock_set: 
     139    if self.decrypt_done: 
     140      if not self.rsa_buffer_context['recv_buffer']: 
    317141        raise SocketClosedRemote("Socket closed remotely!") 
    318       else: 
    319         raise SocketWouldBlockError("No data to be received.") 
    320  
    321  
    322      
     142    elif not self.rsa_buffer_context['recv_buffer']: 
     143      raise SocketWouldBlockError("No data to be received.") 
     144 
    323145    # Extract the data 
    324     msg_to_return = self.rsa_buffer_context[repr(socket)]['recv_buffer'][:bytes] 
    325  
    326     self.rsa_buffer_context[repr(socket)]['recv_lock'].acquire(True) 
     146    msg_to_return = self.rsa_buffer_context['recv_buffer'][:bytes] 
     147 
     148    self.rsa_buffer_context['recv_lock'].acquire(True) 
    327149    try: 
    328       self.rsa_buffer_context[repr(socket)]['recv_buffer'] = self.rsa_buffer_context[repr(socket)]['recv_buffer'][bytes:] 
     150      self.rsa_buffer_context['recv_buffer'] = self.rsa_buffer_context['recv_buffer'][bytes:] 
    329151    finally: 
    330       self.rsa_buffer_context[repr(socket)]['recv_lock'].release() 
    331  
    332      
     152      self.rsa_buffer_context['recv_lock'].release() 
     153 
    333154    return msg_to_return 
    334155 
    335156 
    336  
    337  
    338   def socket_close(self, sockobj): 
    339     """ 
    340     <Purpose> 
    341       Call the next layer of socket_close and remove the socket from  
    342       the active socket set. 
    343     """ 
    344  
    345     return_val = self.get_next_shim_layer().socket_close(sockobj) 
    346  
    347     try: 
    348       self.active_sock_set.remove(repr(sockobj)) 
    349     except: 
    350       pass 
    351  
    352     self.socket_closed_local.append(repr(sockobj)) 
    353  
    354     return return_val 
     157  def close(self): 
     158    self.close_lock.acquire(True) 
     159    result = not self.socket_closed_local 
     160    self.socket_closed_local = True 
     161    self.close_lock.release() 
     162 
     163    return result 
    355164 
    356165 
     
    359168  # Helper Threads 
    360169  # ========================================================= 
    361   def _sending_thread(self, sockobj): 
     170 
     171  def _sending_thread(self): 
    362172    """ 
    363173    <Purpose> 
     
    367177    """ 
    368178 
    369     def _sending_helper(): 
    370       # Continuously run this thread until socket is closed. 
    371       while repr(sockobj) in self.active_sock_set: 
    372         if self.rsa_buffer_context[repr(sockobj)]['send_buffer']: 
    373           msg = self.rsa_buffer_context[repr(sockobj)]['send_buffer'] 
     179    # Continuously run this thread until socket is closed. 
     180    while not self.send_closed_remote: 
     181      if self.rsa_buffer_context['send_buffer']: 
     182        msg = self.rsa_buffer_context['send_buffer'] 
     183        try: 
     184          data_sent = self.socket.send(msg) 
     185        except SocketWouldBlockError: 
     186          sleep(SLEEP_TIME) 
     187        except (SocketClosedLocal, SocketClosedRemote): 
     188          # Since the socket object is closed 
     189          break 
     190        else: 
     191          # Update the buffer to how much data was sent already. 
     192          self.rsa_buffer_context['send_lock'].acquire(True) 
    374193          try: 
    375             data_sent = self.get_next_shim_layer().socket_send(sockobj, msg) 
    376           except SocketWouldBlockError: 
    377             sleep(SLEEP_TIME) 
    378           except (SocketClosedLocal, SocketClosedRemote): 
    379             # Since the socket object is closed 
    380             try: 
    381               self.active_sock_set.remove(repr(sockobj)) 
    382             except: 
    383               pass                         
    384             break 
    385           else: 
    386             # Update the buffer to how much data was sent already. 
    387             self.rsa_buffer_context[repr(sockobj)]['send_lock'].acquire(True) 
    388             try: 
    389               self.rsa_buffer_context[repr(sockobj)]['send_buffer'] = self.rsa_buffer_context[repr(sockobj)]['send_buffer'][data_sent:] 
    390             finally: 
    391               self.rsa_buffer_context[repr(sockobj)]['send_lock'].release() 
    392         else: 
    393           # If we have an empty buffer, we just sleep. 
    394           sleep(SLEEP_TIME) 
    395  
    396     return _sending_helper 
    397  
    398  
    399  
    400  
    401   def _receiving_thread(self, sockobj): 
     194            self.rsa_buffer_context['send_buffer'] = self.rsa_buffer_context['send_buffer'][data_sent:] 
     195          finally: 
     196            self.rsa_buffer_context['send_lock'].release() 
     197      elif self.socket_closed_local and not self.rsa_buffer_context['send_buffer']: 
     198        break 
     199      else: 
     200        # If we have an empty buffer, we just sleep. 
     201        sleep(SLEEP_TIME) 
     202 
     203    self.close_lock.acquire(True) 
     204    self.send_closed_remote = True 
     205    closed_remote = self.recv_closed_remote 
     206    self.close_lock.release() 
     207 
     208    if closed_remote: 
     209      try: 
     210        self.socket.close() 
     211      except Exception: 
     212        pass 
     213 
     214 
     215  def _receiving_thread(self): 
    402216    """ 
    403217    <Purpose> 
     
    409223 
    410224    # Launch the decrypter. 
    411     createthread(self._decrypt_msg(sockobj)) 
    412  
    413     def _receiving_helper(): 
    414       # Keep receiving data and decrypting it. 
    415       while repr(sockobj) in self.active_sock_set: 
     225    createthread(self._decrypt_msg) 
     226 
     227    # Keep receiving data and decrypting it. 
     228    while not self.socket_closed_local: 
     229      try: 
     230        encrypt_msg = self.socket.recv(RECV_SIZE_BYTE) 
     231      except SocketWouldBlockError: 
     232        pass 
     233      except (SocketClosedLocal, SocketClosedRemote): 
     234        # Since the socket object is closed 
     235        self.send_closed_remote = True 
     236        break 
     237      else: 
     238        self.rsa_buffer_context['recv_encrypt_lock'].acquire(True) 
    416239        try: 
    417           encrypt_msg = self.get_next_shim_layer().socket_recv(sockobj, RECV_SIZE_BYTE) 
    418         except SocketWouldBlockError: 
    419           pass 
    420         except (SocketClosedLocal, SocketClosedRemote): 
    421           # Since the socket object is closed 
    422           try: 
    423             self.active_sock_set.remove(repr(sockobj)) 
    424           except: 
    425             pass 
    426           break 
    427         else: 
    428           self.rsa_buffer_context[repr(sockobj)]['recv_encrypt_lock'].acquire(True) 
    429           try: 
    430             self.rsa_buffer_context[repr(sockobj)]['recv_encrypted_buffer'] += encrypt_msg 
    431           finally: 
    432             self.rsa_buffer_context[repr(sockobj)]['recv_encrypt_lock'].release() 
    433  
    434     return _receiving_helper 
    435  
    436  
    437  
    438  
    439   def _decrypt_msg(self, sockobj): 
     240          self.rsa_buffer_context['recv_encrypted_buffer'] += encrypt_msg 
     241        finally: 
     242          self.rsa_buffer_context['recv_encrypt_lock'].release() 
     243 
     244    self.close_lock.acquire(True) 
     245    self.recv_closed_remote = True 
     246    closed_remote = self.send_closed_remote 
     247    self.close_lock.release() 
     248 
     249    if closed_remote: 
     250      try: 
     251        self.socket.close() 
     252      except Exception: 
     253        pass 
     254 
     255 
     256  def _decrypt_msg(self): 
    440257    """ 
    441258    This is a helper function that is used to decrypt the message in the 
     
    445262    """ 
    446263 
    447     def _decrypt_msg_helper(): 
    448       while True: 
    449         # First index of the character '\n'. It is used to determine 
    450         # upto which index in the recv_encrypt_buffer do we have the  
    451         # header for that particular packet that denotes the size of 
    452         # the message. 
    453         # Example of two packets hello and world would be: '5\nhello5\nworld' 
    454         header_index = self.rsa_buffer_context[repr(sockobj)]['recv_encrypted_buffer'].find('\n') 
    455            
    456         # We don't have a full encrypted message yet. 
    457         if header_index == -1: 
    458           if repr(sockobj) not in self.active_sock_set: 
    459             break 
    460           sleep(0.001) 
    461           continue 
    462  
    463         try: 
    464           message_length = int(self.rsa_buffer_context[repr(sockobj)]['recv_encrypted_buffer'][:header_index]) 
    465         except ValueError: 
    466           raise ShimInternalError("Unable to decrypt receiving message due to bad header in recv_encrypt_buffer") 
    467  
    468         # Get the actual message out. 
    469         packet_end_index = header_index + message_length + 1 
    470  
    471         # If we haven't received the entire package, then we can't decrypt. 
    472         if len(self.rsa_buffer_context[repr(sockobj)]['recv_encrypted_buffer']) < packet_end_index: 
    473           if repr(sockobj) not in self.active_socket_set: 
    474             break 
    475           sleep(0.001) 
    476           continue 
    477  
    478         total_data = self.rsa_buffer_context[repr(sockobj)]['recv_encrypted_buffer'][:packet_end_index] 
    479         encrypt_msg = total_data[header_index + 1 : ] 
     264    while True: 
     265      # First index of the character '\n'. It is used to determine 
     266      # upto which index in the recv_encrypt_buffer do we have the  
     267      # header for that particular packet that denotes the size of 
     268      # the message. 
     269      # Example of two packets hello and world would be: '5\nhello5\nworld' 
     270      header_index = self.rsa_buffer_context['recv_encrypted_buffer'].find('\n') 
    480271         
    481         self.rsa_buffer_context[repr(sockobj)]['recv_encrypt_lock'].acquire(True) 
    482         try: 
    483           self.rsa_buffer_context[repr(sockobj)]['recv_encrypted_buffer'] = self.rsa_buffer_context[repr(sockobj)]['recv_encrypted_buffer'][packet_end_index:] 
    484         finally: 
    485           self.rsa_buffer_context[repr(sockobj)]['recv_encrypt_lock'].release() 
    486  
    487         # decrypt the message. 
    488         try: 
    489           actual_msg = rsa_decrypt(encrypt_msg, self.rsa_key_dict[repr(sockobj)]['priv_key']) 
    490         except ValueError: 
    491           raise ShimInternalError("Invalid private key being used for decryption!") 
    492  
    493         # Add the unencrypted data to the recv buffer. 
    494         self.rsa_buffer_context[repr(sockobj)]['recv_lock'].acquire(True) 
    495         try: 
    496           self.rsa_buffer_context[repr(sockobj)]['recv_buffer'] += actual_msg 
    497         finally: 
    498           self.rsa_buffer_context[repr(sockobj)]['recv_lock'].release() 
    499  
     272      # We don't have a full encrypted message yet. 
     273      if header_index == -1: 
     274        if self.recv_closed_remote or self.socket_closed_local: 
     275          break 
     276        sleep(0.001) 
     277        continue 
     278 
     279      try: 
     280        message_length = int(self.rsa_buffer_context['recv_encrypted_buffer'][:header_index]) 
     281      except ValueError: 
     282        raise ShimInternalError("Unable to decrypt receiving message due to bad header in recv_encrypt_buffer") 
     283 
     284      # Get the actual message out. 
     285      packet_end_index = header_index + message_length + 1 
     286 
     287      # If we haven't received the entire package, then we can't decrypt. 
     288      if len(self.rsa_buffer_context['recv_encrypted_buffer']) < packet_end_index: 
     289        if self.recv_closed_remote or self.socket_closed_local: 
     290          break 
     291        sleep(0.001) 
     292        continue 
     293 
     294      total_data = self.rsa_buffer_context['recv_encrypted_buffer'][:packet_end_index] 
     295      encrypt_msg = total_data[header_index + 1 : ] 
    500296       
    501     return _decrypt_msg_helper 
    502  
     297      self.rsa_buffer_context['recv_encrypt_lock'].acquire(True) 
     298      try: 
     299        self.rsa_buffer_context['recv_encrypted_buffer'] = self.rsa_buffer_context['recv_encrypted_buffer'][packet_end_index:] 
     300      finally: 
     301        self.rsa_buffer_context['recv_encrypt_lock'].release() 
     302 
     303      # decrypt the message. 
     304      try: 
     305        actual_msg = rsa_decrypt(encrypt_msg, self.rsa_key_dict['priv_key']) 
     306      except ValueError: 
     307        raise ShimInternalError("Invalid private key being used for decryption!") 
     308 
     309      # Add the unencrypted data to the recv buffer. 
     310      self.rsa_buffer_context['recv_lock'].acquire(True) 
     311      try: 
     312        self.rsa_buffer_context['recv_buffer'] += actual_msg 
     313      finally: 
     314        self.rsa_buffer_context['recv_lock'].release() 
     315 
     316    self.decrypt_done = True 
     317 
     318 
     319 
     320 
     321class RSAShim(BaseShim): 
     322 
     323  def __init__(self, shim_stack = ShimStack(), optional_args = None): 
     324    """ 
     325    <Purpose> 
     326      Initialize the RSAShim. 
     327 
     328    <Arguments> 
     329      shim_stack - the shim stack underneath us. 
     330 
     331      optional_args - The optional args (if provided) will be used to 
     332        encrypt and decrypt data. A new key will not be generated. Note 
     333        that if optional_args is provided then it must be provided for  
     334        both the server side and client side, otherwise they won't be 
     335        able to communicate. 
     336 
     337    <Side Effects> 
     338      None 
     339 
     340    <Exceptions> 
     341      ShimInternalError raised if the optional args provided is not of 
     342      the proper format. 
     343 
     344    <Return> 
     345      None 
     346    """ 
     347     
     348    BaseShim.__init__(self, shim_stack, optional_args) 
     349 
     350  
     351  # ======================================== 
     352  # TCP section of RSA Shim 
     353  # ======================================== 
     354  def openconnection(self, destip, destport, localip, localport, timeout): 
     355    """ 
     356    <Purpose> 
     357      Create a connection and initiate the handshake with the server. 
     358      Make sure that we have the same method of communication. 
     359 
     360    <Arguments> 
     361      Same arguments as Repy V2 Api for openconnection. 
     362 
     363    <Side Effects> 
     364      Some messages are sent back and forth. 
     365 
     366    <Exceptions> 
     367      Same exceptions as Repy V2 Api for openconnection. Note that 
     368      a ConnectionRefusedError is raised if the handhake fails with 
     369      the server. 
     370 
     371    <Return> 
     372      A socket like object. 
     373    """ 
     374 
     375    # Open a connection by calling the next layer of shim. 
     376    sockobj = self.get_next_shim_layer().openconnection(destip, destport, localip, localport, timeout) 
     377 
     378    # Generate a new set of pubkey/privkey and send the pub 
     379    # key back to the server to receive the actual key. 
     380    (temp_pub, temp_priv) = rsa_gen_pubpriv_keys(BITSIZE) 
     381     
     382    # Greet the server and send it a temporary pubkey and wait 
     383    # for a response. 
     384    session_sendmessage(sockobj, GREET_TAG + str(temp_pub)) 
     385    encrypted_response = session_recvmessage(sockobj) 
     386 
     387    response = rsa_decrypt(encrypted_response, temp_priv) 
     388     
     389    if response.startswith(NEW_SHARE_TAG): 
     390      key_pair = response.split(NEW_SHARE_TAG)[1] 
     391 
     392      (pub_key, priv_key) = key_pair.split(':::') 
     393 
     394      rsa_key_dict = {} 
     395      rsa_key_dict['pub_key'] = eval(pub_key) 
     396      rsa_key_dict['priv_key'] = eval(priv_key) 
     397 
     398      rsa_buffer_context = {'send_buffer' : '', 
     399                            'recv_encrypted_buffer' : '', 
     400                            'recv_buffer' : '', 
     401                            'send_lock' : createlock(), 
     402                            'recv_lock' : createlock(), 
     403                            'recv_encrypt_lock' : createlock()} 
     404    else: 
     405      raise ConnectionRefusedError("Unable to complete handshake with server and " + 
     406                                   "agree on RSA key.") 
     407 
     408    return RSASocket(sockobj, rsa_key_dict, rsa_buffer_context) 
     409 
     410 
     411  def tcpserversocket_getconnection(self, tcpserversocket): 
     412    """ 
     413    <Purpose> 
     414      Accept a connection from the client. Complete a handshake 
     415      to make sure that both the server and client have the same 
     416      pub/priv key. 
     417    
     418    <Arguments> 
     419      Same arguments as Repy V2 Api for tcpserver.getconnection() 
     420 
     421    <Side Effects> 
     422      Some messages are sent back and forth. Some RSA keys are generated 
     423      so things might slow down. 
     424 
     425    <Return> 
     426      A tuple of remoteip, remoteport and socket like object. 
     427    """ 
     428 
     429    # Call the next layer of socket to get a connection. 
     430    (remoteip, remoteport, sockobj) = self.get_next_shim_layer().tcpserversocket_getconnection(tcpserversocket) 
     431 
     432    # Try to get the initial greeting from the connection. 
     433    try: 
     434      initial_msg = session_recvmessage(sockobj) 
     435    except (ValueError, SessionEOF): 
     436      raise SocketWouldBlockError("No connection available right now.") 
     437 
     438    # If we get a greeting tag then we send back to the client a new set 
     439    # of key that will be used to do all the communication. 
     440    if initial_msg.startswith(GREET_TAG): 
     441      # Extract the pubkey and convert it to dict. 
     442      client_pubkey = eval(initial_msg.split(GREET_TAG)[1]) 
     443 
     444      # Generate new key. 
     445      (pub_key, priv_key) = rsa_gen_pubpriv_keys(BITSIZE) 
     446 
     447      rsa_key_dict = {} 
     448      rsa_key_dict['pub_key'] = pub_key 
     449      rsa_key_dict['priv_key'] = priv_key 
     450 
     451      rsa_buffer_context = {'send_buffer' : '', 
     452                            'recv_encrypted_buffer' : '', 
     453                            'recv_buffer' : '', 
     454                            'send_lock' : createlock(), 
     455                            'recv_lock' : createlock(), 
     456                            'recv_encrypt_lock' : createlock()} 
     457 
     458      # Send back the new set of keys, encrypted with the pubkey 
     459      # provided by the client initially. 
     460      new_msg = NEW_SHARE_TAG + str(pub_key) + ':::' + str(priv_key) 
     461      session_sendmessage(sockobj, rsa_encrypt(new_msg, client_pubkey)) 
     462 
     463    else: 
     464      raise ConnectionRefusedError("Unable to complete handshake with server and " + 
     465                                   "agree on RSA key.") 
     466 
     467    return (remoteip, remoteport, RSASocket(sockobj, rsa_key_dict, rsa_buffer_context)) 
     468 
     469 
     470  def socket_send(self, socket, msg): 
     471    return socket.send(msg) 
     472 
     473 
     474  def socket_recv(self, socket, bytes): 
     475    return socket.recv(bytes) 
     476 
     477 
     478  def socket_close(self, socket): 
     479    return socket.close() 
    503480 
    504481 
     
    516493    shim_stack_copy = self.shim_context['shim_stack'].copy() 
    517494    optional_args = self.shim_context['optional_args'] 
     495 
    518496    return RSAShim(shim_stack_copy, optional_args) 
    519497