Source code for wpull.proxy.server

# encoding=utf-8
'''Proxy Tools'''
import enum
import gettext
import logging
import ssl
import os
import socket

import asyncio

import errno

from wpull.application.hook import HookableMixin, HookDisconnected
from wpull.backport.logging import BraceMessage as __
from wpull.body import Body
from wpull.errors import ProtocolError, NetworkError
from wpull.protocol.http.client import Client, Session
from wpull.protocol.http.request import Request
import wpull.util


_ = gettext.gettext
_logger = logging.getLogger(__name__)


[docs]class HTTPProxyServer(HookableMixin): '''HTTP proxy server for use with man-in-the-middle recording. This function is meant to be used as a callback:: asyncio.start_server(HTTPProxyServer(HTTPClient)) Args: http_client (:class:`.http.client.Client`): The HTTP client. Attributes: request_callback: A callback function that accepts a Request. pre_response_callback: A callback function that accepts a Request and Response response_callback: A callback function that accepts a Request and Response '''
[docs] class Event(enum.Enum): begin_session = 'begin_session' end_session = 'end_session'
def __init__(self, http_client: Client): super().__init__() self._http_client = http_client self.event_dispatcher.register(self.Event.begin_session) self.event_dispatcher.register(self.Event.end_session) @asyncio.coroutine def __call__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): '''Handle a request Coroutine.''' _logger.debug('New proxy connection.') session = self._new_session(reader, writer) self.event_dispatcher.notify(self.Event.begin_session, session) is_error = False try: yield from session() except Exception as error: if not isinstance(error, StopIteration): error = True if isinstance(error, (ConnectionAbortedError, ConnectionResetError)): # Client using the proxy has closed the connection _logger.debug('Proxy error', exc_info=True) else: _logger.exception('Proxy error') writer.close() else: raise finally: self.event_dispatcher.notify(self.Event.end_session, session, error=is_error) writer.close() _logger.debug('Proxy connection closed.') def _new_session(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> 'HTTPProxySession': return HTTPProxySession(self._http_client, reader, writer)
[docs]class HTTPProxySession(HookableMixin):
[docs] class Event(enum.Enum): client_request = 'client_request' server_begin_response = 'server_begin_response' server_end_response = 'server_end_response' server_response_error = 'server_response_error'
def __init__(self, http_client: Client, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): super().__init__() self._http_client = http_client self._reader = self._original_reader = reader self._writer = self._original_writer = writer self._is_tunnel = False self._is_ssl_tunnel = False self._cert_filename = wpull.util.get_package_filename('proxy/proxy.crt') self._key_filename = wpull.util.get_package_filename('proxy/proxy.key') assert os.path.isfile(self._cert_filename), self._cert_filename assert os.path.isfile(self._key_filename), self._key_filename self.hook_dispatcher.register(self.Event.client_request) self.hook_dispatcher.register(self.Event.server_begin_response) self.event_dispatcher.register(self.Event.server_end_response) self.event_dispatcher.register(self.Event.server_response_error) @asyncio.coroutine def __call__(self): '''Process a connection session.''' _logger.debug('Begin session.') while True: request = yield from self._read_request_header() if not request: return yield from self._process_request(request) @asyncio.coroutine def _process_request(self, request: Request): _logger.debug(__('Got request {0}', request)) if request.method == 'CONNECT': yield from self._start_connect_tunnel() return if self._is_ssl_tunnel and request.url.startswith('http://'): # Since we are spying under a SSL tunnel, assume processed requests # are SSL request.url = request.url.replace('http://', 'https://', 1) if 'Upgrade' in request.fields.get('Connection', ''): _logger.warning(__( _('Connection Upgrade not supported for {}'), request.url )) self._reject_request('Upgrade not supported') return _logger.debug('Begin response.') try: action = self.hook_dispatcher.call(self.Event.client_request, request) except HookDisconnected: pass else: if not action: _logger.debug('Proxy force reject request') self._reject_request() return with self._http_client.session() as session: if 'Content-Length' in request.fields: request.body = self._reader try: response = yield from session.start(request) except NetworkError as error: _logger.debug('Upstream error', exc_info=True) self._write_error_response() self.event_dispatcher.notify(self.Event.server_response_error, error) return response.body = Body() try: action = self.hook_dispatcher.call(self.Event.server_begin_response, response) except HookDisconnected: pass else: if not action: _logger.debug('Proxy force reject request via response') self._reject_request() return try: self._writer.write(response.to_bytes()) yield from self._writer.drain() session.event_dispatcher.add_listener( Session.Event.response_data, self._writer.write ) yield from session.download(file=response.body, raw=True) yield from self._writer.drain() except NetworkError as error: _logger.debug('Upstream error', exc_info=True) self.event_dispatcher.notify(self.Event.server_response_error, error) raise self.event_dispatcher.notify(self.Event.server_end_response, response) _logger.debug('Response done.') @asyncio.coroutine def _start_connect_tunnel(self): if self._is_tunnel: self._reject_request('Cannot CONNECT within CONNECT') return self._is_tunnel = True original_socket = yield from self._detach_socket_and_start_tunnel() is_ssl = yield from self._is_client_request_ssl(original_socket) if is_ssl: _logger.debug('Tunneling as SSL') yield from self._start_ssl_tunnel() else: yield from self._rewrap_socket(original_socket) @classmethod @asyncio.coroutine def _is_client_request_ssl(cls, socket_: socket.socket) -> bool: while True: original_timeout = socket_.gettimeout() socket_.setblocking(False) try: data = socket_.recv(3, socket.MSG_PEEK) except OSError as error: if error.errno in (errno.EWOULDBLOCK, errno.EAGAIN): yield from asyncio.sleep(0.01) else: raise else: break finally: socket_.settimeout(original_timeout) _logger.debug('peeked data %s', data) if all(ord('A') <= char_code <= ord('Z') for char_code in data): return False else: return True @asyncio.coroutine def _start_ssl_tunnel(self): '''Start SSL protocol on the socket.''' self._is_ssl_tunnel = True ssl_socket = yield from self._start_ssl_handshake() yield from self._rewrap_socket(ssl_socket) @asyncio.coroutine def _detach_socket_and_start_tunnel(self) -> socket.socket: socket_ = self._writer.get_extra_info('socket') try: asyncio.get_event_loop().remove_reader(socket_.fileno()) except ValueError as error: raise ConnectionAbortedError() from error self._writer.write(b'HTTP/1.1 200 Connection established\r\n\r\n') yield from self._writer.drain() try: asyncio.get_event_loop().remove_writer(socket_.fileno()) except ValueError as error: raise ConnectionAbortedError() from error return socket_ @asyncio.coroutine def _start_ssl_handshake(self): socket_ = self._writer.get_extra_info('socket') ssl_socket = ssl.wrap_socket( socket_, server_side=True, certfile=self._cert_filename, keyfile=self._key_filename, do_handshake_on_connect=False ) # FIXME: this isn't how to START TLS for dummy in range(1200): try: ssl_socket.do_handshake() break except ssl.SSLError as error: if error.errno in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE): _logger.debug('Do handshake %s', error) yield from asyncio.sleep(0.05) else: raise else: _logger.error(_('Unable to handshake.')) ssl_socket.close() self._reject_request('Could not start TLS') raise ConnectionAbortedError('Could not start TLS') return ssl_socket @asyncio.coroutine def _rewrap_socket(self, new_socket): loop = asyncio.get_event_loop() reader = asyncio.StreamReader(loop=loop) protocol = asyncio.StreamReaderProtocol(reader, loop=loop) transport, dummy = yield from loop.create_connection( lambda: protocol, sock=new_socket) writer = asyncio.StreamWriter(transport, protocol, reader, loop) self._reader = reader self._writer = writer @asyncio.coroutine def _read_request_header(self) -> Request: request = Request() for dummy in range(100): line = yield from self._reader.readline() _logger.debug(__('Got line {0}', line)) if line[-1:] != b'\n': return if not line.strip(): break request.parse(line) else: raise ProtocolError('Request has too many headers.') return request def _reject_request(self, message='Gateway Request Not Allowed'): '''Send HTTP 501 and close the connection.''' self._write_error_response(501, message) def _write_error_response(self, code=502, message='Bad Gateway Upstream Error'): self._writer.write( 'HTTP/1.1 {} {}\r\n'.format(code, message).encode('ascii', 'replace') ) self._writer.write(b'\r\n') self._writer.close()
def _main_test(): from wpull.protocol.http.client import Client logging.basicConfig(level=logging.DEBUG) http_client = Client() proxy = HTTPProxyServer(http_client) asyncio.get_event_loop().run_until_complete(asyncio.start_server(proxy, port=8888)) asyncio.get_event_loop().run_forever() if __name__ == '__main__': _main_test()