diff --git a/flask_annex/base.py b/flask_annex/base.py index ac877a9..85c96cc 100644 --- a/flask_annex/base.py +++ b/flask_annex/base.py @@ -26,5 +26,5 @@ def save_file(self, key, in_file): def send_file(self, key): raise NotImplementedError() - def send_upload_info(self, key): + def get_upload_info(self, key): raise NotImplementedError() diff --git a/flask_annex/file.py b/flask_annex/file.py index 6ed8cf0..4854b20 100644 --- a/flask_annex/file.py +++ b/flask_annex/file.py @@ -99,5 +99,5 @@ def send_file(self, key): attachment_filename=os.path.basename(key), ) - def send_upload_info(self, key): + def get_upload_info(self, key): raise NotImplementedError("file annex does not support upload info") diff --git a/flask_annex/s3.py b/flask_annex/s3.py index 0cb450d..e833455 100644 --- a/flask_annex/s3.py +++ b/flask_annex/s3.py @@ -10,6 +10,8 @@ DEFAULT_EXPIRES_IN = 300 +MISSING = object() + # ----------------------------------------------------------------------------- @@ -21,6 +23,7 @@ def __init__( access_key_id=None, secret_access_key=None, expires_in=DEFAULT_EXPIRES_IN, + max_content_length=MISSING, ): self._client = boto3.client( 's3', @@ -31,6 +34,7 @@ def __init__( self._bucket_name = bucket_name self._expires_in = expires_in + self._max_content_length = max_content_length def delete(self, key): self._client.delete_object(Bucket=self._bucket_name, Key=key) @@ -91,7 +95,7 @@ def send_file(self, key): ) return flask.redirect(url) - def send_upload_info(self, key): + def get_upload_info(self, key): fields = {} conditions = [] @@ -99,7 +103,10 @@ def send_upload_info(self, key): if content_type: fields['Content-Type'] = content_type - max_content_length = flask.current_app.config['MAX_CONTENT_LENGTH'] + if self._max_content_length is not MISSING: + max_content_length = self._max_content_length + else: + max_content_length = flask.current_app.config['MAX_CONTENT_LENGTH'] if max_content_length is not None: conditions.append( ('content-length-range', 0, max_content_length), @@ -117,8 +124,8 @@ def send_upload_info(self, key): ExpiresIn=self._expires_in, ) - return flask.jsonify( - method='POST', - url=post_info['url'], - data=post_info['fields'], - ) + return { + 'method': 'POST', + 'url': post_info['url'], + 'data': post_info['fields'], + } diff --git a/tests/helpers.py b/tests/helpers.py index 9c90042..1f61f73 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,6 +1,7 @@ from io import BytesIO import json +import flask import pytest # ----------------------------------------------------------------------------- @@ -14,8 +15,8 @@ def assert_key_value(annex, key, value): assert out_file.read() == value -def get_upload_info(client, key): - response = client.get('/upload_info/{}'.format(key)) +def get_upload_info(client, key, **kwargs): + response = client.get('/upload_info/{}'.format(key), **kwargs) return json.loads(response.get_data(as_text=True)) @@ -31,13 +32,30 @@ def annex(self, annex_base): @pytest.fixture(autouse=True) def routes(self, app, annex): - @app.route('/file/') + @app.route('/files/', methods=('GET', 'PUT')) def file(key): + if flask.request.method != 'GET': + raise NotImplementedError() return annex.send_file(key) @app.route('/upload_info/') def upload_info(key): - return annex.send_upload_info(key) + try: + upload_info = annex.get_upload_info(key) + except NotImplementedError: + upload_info = { + 'method': 'PUT', + 'url': flask.url_for( + 'file', key=key, _method='PUT', _external=True, + ), + 'headers': { + 'Authorization': flask.request.headers.get( + 'Authorization', + ), + }, + } + + return flask.jsonify(upload_info) def test_get_file(self, annex): assert_key_value(annex, 'foo/bar.txt', b'1\n') diff --git a/tests/test_file.py b/tests/test_file.py index 2123e25..dc5128e 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -4,7 +4,7 @@ from flask_annex import Annex -from helpers import AbstractTestAnnex, assert_key_value +from helpers import AbstractTestAnnex, assert_key_value, get_upload_info # ----------------------------------------------------------------------------- @@ -27,14 +27,32 @@ def test_save_file_existing_dir(self, annex): assert_key_value(annex, 'foo/qux.txt', b'6\n') def test_send_file(self, client): - response = client.get('/file/foo/baz.json') + response = client.get('/files/foo/baz.json') assert response.status_code == 200 assert response.mimetype == 'application/json' assert 'attachment' in response.headers['Content-Disposition'] - def test_send_upload_info(self, annex): - with pytest.raises(NotImplementedError): - annex.send_upload_info('foo/qux.txt') + def test_get_upload_info(self, client): + upload_info = get_upload_info(client, 'foo/qux.txt') + assert upload_info == { + 'method': 'PUT', + 'url': 'http://localhost/files/foo/qux.txt', + 'headers': { + 'Authorization': None, + }, + } + + def test_get_upload_info_authorized(self, client): + upload_info = get_upload_info(client, 'foo/qux.txt', headers={ + 'Authorization': 'Bearer foo', + }) + assert upload_info == { + 'method': 'PUT', + 'url': 'http://localhost/files/foo/qux.txt', + 'headers': { + 'Authorization': 'Bearer foo', + }, + } class TestFileAnnexFromEnv(TestFileAnnex): diff --git a/tests/test_s3.py b/tests/test_s3.py index 3410424..67e5a66 100644 --- a/tests/test_s3.py +++ b/tests/test_s3.py @@ -62,7 +62,7 @@ def test_save_file_unknown_type(self, annex): assert_key_value(annex, 'foo/qux', b'6\n') def test_send_file(self, client): - response = client.get('/file/foo/baz.json') + response = client.get('/files/foo/baz.json') assert response.status_code == 302 s3_url = response.headers['Location'] @@ -90,12 +90,21 @@ def test_get_upload_info(self, client): assert get_condition(conditions, 'key') == 'foo/qux.txt' assert get_condition(conditions, 'Content-Type') == 'text/plain' + self.assert_default_content_length_range(conditions) + + def assert_default_content_length_range(self, conditions): + with pytest.raises(KeyError): + get_condition(conditions, 'content-length-range') + def test_get_upload_info_max_content_length(self, app, client): app.config['MAX_CONTENT_LENGTH'] = 100 upload_info = get_upload_info(client, 'foo/qux.txt') conditions = get_policy(upload_info)['conditions'] + self.assert_app_config_content_length_range(conditions) + + def assert_app_config_content_length_range(self, conditions): assert get_condition(conditions, 'content-length-range') == [0, 100] @@ -107,3 +116,15 @@ def annex_base(self, monkeypatch, bucket_name): monkeypatch.setenv('FLASK_ANNEX_S3_REGION', 'us-east-1') return Annex.from_env('FLASK_ANNEX') + + +class TestS3AnnexMaxContentLength(TestS3Annex): + @pytest.fixture + def annex_base(self, bucket_name): + return Annex('s3', bucket_name, max_content_length=1000) + + def assert_default_content_length_range(self, conditions): + assert get_condition(conditions, 'content-length-range') == [0, 1000] + + def assert_app_config_content_length_range(self, conditions): + assert get_condition(conditions, 'content-length-range') == [0, 1000]