diff --git a/.gitignore b/.gitignore index c3669a6..9067def 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ venv/* streamedrequests/__pycache__/* +testdata/* +.vscode/* \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 53d8ec2..0000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "python.pythonPath": "venv/bin/python3.7" -} \ No newline at end of file diff --git a/streamedrequests/dodownload.py b/streamedrequests/dodownload.py index 20ca830..baa9606 100644 --- a/streamedrequests/dodownload.py +++ b/streamedrequests/dodownload.py @@ -18,11 +18,11 @@ import threading def __run_callback(data, sync, callback=None): - if callback is None: + if callback is None: # Do nothing if there is no callback return - if sync: + if sync: # If synchronous (default), run callback normally callback(data) - else: + else: # If async, spawn a new thread (not good for CPU-bound cases) threading.Thread(target=callback, args=(data,)).start() def __do_download(req, max_size, chunk_size, callback, sync): diff --git a/streamedrequests/get.py b/streamedrequests/get.py index cd7440d..6b6aead 100644 --- a/streamedrequests/get.py +++ b/streamedrequests/get.py @@ -23,7 +23,8 @@ def get(url, query_parameters=None, request_headers=None, sync=True, chunk_count = responsesize.SizeValidator(max_size) # Class to verify if the stream is staying within the max_size - timeouts = setuptimeout.__setup_timeout(connect_timeout, stream_timeout) + timeouts = setuptimeout.__setup_timeout(connect_timeout, stream_timeout) # Get a timeout int or tuple + # Requests uses separate value for connect vs stream timeout req = requests.get(url, params=query_parameters, headers=request_headers, timeout=timeouts, stream=True, allow_redirects=allow_redirects, proxies=proxy) diff --git a/streamedrequests/post.py b/streamedrequests/post.py index 3ccad81..687e26c 100644 --- a/streamedrequests/post.py +++ b/streamedrequests/post.py @@ -16,7 +16,15 @@ along with this program. If not, see . ''' import requests +from . import exceptions, responsesize, dodownload, setuptimeout def post(url, post_data=None, request_headers=None, sync=True, max_size=0, chunk_size=1000, connect_timeout=60, stream_timeout=0, proxy={}, callback=None, allow_redirects=True): - return \ No newline at end of file + chunk_count = responsesize.SizeValidator(max_size) # Class to verify if the stream is staying within the max_size + + timeouts = setuptimeout.__setup_timeout(connect_timeout, stream_timeout) + + req = requests.post(url, data=post_data, headers=request_headers, + timeout=timeouts, stream=True, allow_redirects=allow_redirects, proxies=proxy) + + return dodownload.__do_download(req, max_size, chunk_size, callback, sync) \ No newline at end of file diff --git a/tests/test_basic.py b/tests/test.py similarity index 61% rename from tests/test_basic.py rename to tests/test.py index 2163417..77af4a1 100644 --- a/tests/test_basic.py +++ b/tests/test.py @@ -15,47 +15,67 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . ''' -import sys, os, unittest, threading -from http.server import HTTPServer, SimpleHTTPRequestHandler +import sys, os, unittest, threading, atexit +from http.server import HTTPServer, SimpleHTTPRequestHandler, BaseHTTPRequestHandler import requests sys.path.append(os.path.dirname(os.path.realpath(__file__)) + "/../") import streamedrequests +test_data_1 = 'test '*1000 + '\ntwo\n' +test_data = test_data_1 + 'test2'*1000 + +class S(BaseHTTPRequestHandler): + def POST(self): + # Doesn't do anything with posted data + self._set_headers() + self.wfile.write("

POST!

") + def get_test_id(): return str(uuid.uuid4()) + '.dat' +def setup(): + if not os.path.exists('testdata'): + os.mkdir('testdata') + with open('index.html', 'w') as testfile: + testfile.write(test_data) + def run(server_class=HTTPServer, handler_class=SimpleHTTPRequestHandler): server_address = ('127.0.0.1', 8000) httpd = server_class(server_address, handler_class) httpd.serve_forever() +def run_post(server_class=S, handler_class=BaseHTTPRequestHandler): + server_address = ('127.0.0.1', 8001) + httpd = server_class(server_address, handler_class) + httpd.serve_forever() + def _test_callback(text): - return - #print('got', text) + print('got', text) class TestInit(unittest.TestCase): def test_requests(self): - requests.get('http://127.0.0.1:8000/') + if "test" not in requests.get('http://127.0.0.1:8000/').text: + raise ValueError("test not found in test data") def test_basic(self): streamedrequests.get('http://127.0.0.1:8000/') def test_callback(self): - pass streamedrequests.get('http://127.0.0.1:8000/', chunk_size=1, callback=_test_callback) def test_async(self): streamedrequests.get('http://127.0.0.1:8000/', chunk_size=1, callback=_test_callback, sync=False) def test_zero_chunk_size(self): - try: + with self.assertRaises(ValueError): streamedrequests.get('http://127.0.0.1:8000/', chunk_size=0) - except ValueError: - pass - else: - self.assertTrue(failUnless) + + def test_post(self): + streamedrequests.post('http://127.0.0.1:8000/') +setup() threading.Thread(target=run, daemon=True).start() +threading.Thread(target=run_post, daemon=True).start() unittest.main() \ No newline at end of file