work on tests

This commit is contained in:
Kevin Froman 2019-07-07 19:16:41 -05:00
parent cb8353336e
commit b693a23a6f
6 changed files with 47 additions and 19 deletions

2
.gitignore vendored
View File

@ -1,2 +1,4 @@
venv/* venv/*
streamedrequests/__pycache__/* streamedrequests/__pycache__/*
testdata/*
.vscode/*

View File

@ -1,3 +0,0 @@
{
"python.pythonPath": "venv/bin/python3.7"
}

View File

@ -18,11 +18,11 @@
import threading import threading
def __run_callback(data, sync, callback=None): def __run_callback(data, sync, callback=None):
if callback is None: if callback is None: # Do nothing if there is no callback
return return
if sync: if sync: # If synchronous (default), run callback normally
callback(data) callback(data)
else: else: # If async, spawn a new thread (not good for CPU-bound cases)
threading.Thread(target=callback, args=(data,)).start() threading.Thread(target=callback, args=(data,)).start()
def __do_download(req, max_size, chunk_size, callback, sync): def __do_download(req, max_size, chunk_size, callback, sync):

View File

@ -23,7 +23,8 @@ def get(url, query_parameters=None, request_headers=None, sync=True,
chunk_count = responsesize.SizeValidator(max_size) # Class to verify if the stream is staying within the max_size chunk_count = responsesize.SizeValidator(max_size) # Class to verify if the stream is staying within the max_size
timeouts = setuptimeout.__setup_timeout(connect_timeout, stream_timeout) timeouts = setuptimeout.__setup_timeout(connect_timeout, stream_timeout) # Get a timeout int or tuple
# Requests uses separate value for connect vs stream timeout
req = requests.get(url, params=query_parameters, headers=request_headers, req = requests.get(url, params=query_parameters, headers=request_headers,
timeout=timeouts, stream=True, allow_redirects=allow_redirects, proxies=proxy) timeout=timeouts, stream=True, allow_redirects=allow_redirects, proxies=proxy)

View File

@ -16,7 +16,15 @@
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
''' '''
import requests import requests
from . import exceptions, responsesize, dodownload, setuptimeout
def post(url, post_data=None, request_headers=None, sync=True, def post(url, post_data=None, request_headers=None, sync=True,
max_size=0, chunk_size=1000, connect_timeout=60, stream_timeout=0, max_size=0, chunk_size=1000, connect_timeout=60, stream_timeout=0,
proxy={}, callback=None, allow_redirects=True): proxy={}, callback=None, allow_redirects=True):
return chunk_count = responsesize.SizeValidator(max_size) # Class to verify if the stream is staying within the max_size
timeouts = setuptimeout.__setup_timeout(connect_timeout, stream_timeout)
req = requests.post(url, data=post_data, headers=request_headers,
timeout=timeouts, stream=True, allow_redirects=allow_redirects, proxies=proxy)
return dodownload.__do_download(req, max_size, chunk_size, callback, sync)

View File

@ -15,47 +15,67 @@
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
''' '''
import sys, os, unittest, threading import sys, os, unittest, threading, atexit
from http.server import HTTPServer, SimpleHTTPRequestHandler from http.server import HTTPServer, SimpleHTTPRequestHandler, BaseHTTPRequestHandler
import requests import requests
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/../") sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/../")
import streamedrequests import streamedrequests
test_data_1 = 'test '*1000 + '\ntwo\n'
test_data = test_data_1 + 'test2'*1000
class S(BaseHTTPRequestHandler):
def POST(self):
# Doesn't do anything with posted data
self._set_headers()
self.wfile.write("<html><body><h1>POST!</h1></body></html>")
def get_test_id(): def get_test_id():
return str(uuid.uuid4()) + '.dat' return str(uuid.uuid4()) + '.dat'
def setup():
if not os.path.exists('testdata'):
os.mkdir('testdata')
with open('index.html', 'w') as testfile:
testfile.write(test_data)
def run(server_class=HTTPServer, handler_class=SimpleHTTPRequestHandler): def run(server_class=HTTPServer, handler_class=SimpleHTTPRequestHandler):
server_address = ('127.0.0.1', 8000) server_address = ('127.0.0.1', 8000)
httpd = server_class(server_address, handler_class) httpd = server_class(server_address, handler_class)
httpd.serve_forever() httpd.serve_forever()
def run_post(server_class=S, handler_class=BaseHTTPRequestHandler):
server_address = ('127.0.0.1', 8001)
httpd = server_class(server_address, handler_class)
httpd.serve_forever()
def _test_callback(text): def _test_callback(text):
return print('got', text)
#print('got', text)
class TestInit(unittest.TestCase): class TestInit(unittest.TestCase):
def test_requests(self): def test_requests(self):
requests.get('http://127.0.0.1:8000/') if "test" not in requests.get('http://127.0.0.1:8000/').text:
raise ValueError("test not found in test data")
def test_basic(self): def test_basic(self):
streamedrequests.get('http://127.0.0.1:8000/') streamedrequests.get('http://127.0.0.1:8000/')
def test_callback(self): def test_callback(self):
pass
streamedrequests.get('http://127.0.0.1:8000/', chunk_size=1, callback=_test_callback) streamedrequests.get('http://127.0.0.1:8000/', chunk_size=1, callback=_test_callback)
def test_async(self): def test_async(self):
streamedrequests.get('http://127.0.0.1:8000/', chunk_size=1, callback=_test_callback, sync=False) streamedrequests.get('http://127.0.0.1:8000/', chunk_size=1, callback=_test_callback, sync=False)
def test_zero_chunk_size(self): def test_zero_chunk_size(self):
try: with self.assertRaises(ValueError):
streamedrequests.get('http://127.0.0.1:8000/', chunk_size=0) streamedrequests.get('http://127.0.0.1:8000/', chunk_size=0)
except ValueError:
pass
else:
self.assertTrue(failUnless)
def test_post(self):
streamedrequests.post('http://127.0.0.1:8000/')
setup()
threading.Thread(target=run, daemon=True).start() threading.Thread(target=run, daemon=True).start()
threading.Thread(target=run_post, daemon=True).start()
unittest.main() unittest.main()