3 changes: 3 additions & 0 deletions CHANGELOG.rst
@@ -16,6 +16,9 @@ Release 1.3
**Not released yet**

* feat: Nicer error messages when reading from a closed ``MultipartPart``.
* feat: Support custom ``MultipartSegment`` subclasses to be used and emitted by
  ``PushMultipartParser`` (see the usage sketch below). The API between parser and
  segment is not stable yet; overriding any of the ``_on_*`` methods may break
  between releases.

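A minimal usage sketch, assuming the new ``segment_class`` keyword introduced by
this change; the ``UploadSegment`` subclass and the example boundary are
illustrative only::

    from multipart import PushMultipartParser, MultipartSegment

    class UploadSegment(MultipartSegment):
        """Placeholder subclass; override ``_on_*`` hooks here (unstable API)."""

    body = (b'--boundary123\r\n'
            b'Content-Disposition: form-data; name="field"\r\n'
            b'\r\n'
            b'value\r\n'
            b'--boundary123--\r\n')

    with PushMultipartParser("boundary123", segment_class=UploadSegment) as parser:
        for result in parser.parse(body):
            if isinstance(result, UploadSegment):
                print("segment:", result.name)    # headers are complete
            elif result is not None:
                print("data:", bytes(result))     # bytearray with body data
            else:
                print("end of segment")           # current segment finished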
Release 1.2
===========
103 changes: 57 additions & 46 deletions multipart.py
@@ -20,7 +20,7 @@

import re
from io import BytesIO
from typing import Iterator, Union, Optional, Tuple, List
from typing import Generic, Iterator, Type, TypeVar, Union, Optional, Tuple, List
from urllib.parse import parse_qs
from wsgiref.headers import Headers
from collections.abc import MutableMapping as DictMixin
@@ -280,8 +280,10 @@ def parse_options_header(header, options=None, unquote=header_unquote):
_BODY = "BODY"
_COMPLETE = "END"

t_segment = TypeVar("t_segment", bound="MultipartSegment")

class PushMultipartParser(Generic[t_segment]):

class PushMultipartParser:
def __init__(
self,
boundary: Union[str, bytes],
@@ -292,6 +294,7 @@ def __init__(
max_segment_count=inf, # unlimited
header_charset="utf8",
strict=False,
segment_class: Optional[Type[t_segment]] = None,
):
"""A push-based (incremental, non-blocking) parser for multipart/form-data.

@@ -311,6 +314,8 @@
:param max_segment_count: Maximum number of segments.
:param header_charset: Charset for header names and values.
:param strict: Enables additional format and sanity checks.

:param segment_class: The class to use for emitted segments, defaults to :class:`MultipartSegment`.
"""
self.boundary = to_bytes(boundary)
self.content_length = content_length
@@ -321,13 +326,17 @@ def __init__(
self.max_segment_count = max_segment_count
self.strict = strict

self._delimiter = b"\r\n--" + self.boundary
if segment_class and issubclass(segment_class, MultipartSegment):
self.segment_class = segment_class
else:
self.segment_class = MultipartSegment

# Internal parser state
self._delimiter = b"\r\n--" + self.boundary
self._parsed = 0
self._fieldcount = 0
self._buffer = bytearray()
self._current = None
self._segment_count = 0
self._segment = None
self._state = _PREAMBLE

#: True if the parser reached the end of the multipart stream, stopped
@@ -344,7 +353,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):

def parse(
self, chunk: Union[bytes, bytearray]
) -> Iterator[Union["MultipartSegment", bytearray, None]]:
) -> Iterator[Union[t_segment, bytearray, None]]:
"""Parse a chunk of data and yield as many result objects as possible
with the data given.

@@ -406,7 +415,7 @@ def parse(
tail = buffer[next_start-2 : next_start]

if tail == b"\r\n": # Normal delimiter found
self._current = MultipartSegment(self)
self._segment = self._new_segment()
self._state = _HEADER
offset = next_start
continue
@@ -433,12 +442,12 @@ def parse(
nl = buffer.find(b"\r\n", offset)

if nl > offset: # Non-empty header line
self._current._add_headerline(buffer[offset:nl])
self._segment._on_headerline(buffer[offset:nl])
offset = nl + 2
continue
elif nl == offset: # Empty header line -> End of header section
self._current._close_headers()
yield self._current
self._segment._on_header_complete()
yield self._segment
self._state = _BODY
offset += 2
continue
@@ -463,27 +472,25 @@

if tail == b"\r\n" or tail == b"--":
if index > offset:
self._current._update_size(index - offset)
yield buffer[offset:index]
yield self._segment._on_data(buffer[offset:index])

offset = next_start
self._current._mark_complete()
self._segment._on_data_complete()
yield None # End of segment

if tail == b"--": # Last delimiter
self._state = _COMPLETE
break
else: # Normal delimiter
self._current = MultipartSegment(self)
self._segment = self._new_segment()
self._state = _HEADER
continue

# Keep enough in the buffer to account for a partial delimiter at
# the end, but emit the rest.
chunk_end = bufferlen - (d_len + 1)
assert chunk_end > offset # Always true
self._current._update_size(chunk_end - offset)
yield buffer[offset:chunk_end]
yield self._segment._on_data(buffer[offset:chunk_end])
offset = chunk_end
break # wait for more data

@@ -501,6 +508,12 @@ def parse(
self.close(check_complete=False)
raise

def _new_segment(self) -> t_segment:
self._segment_count += 1
if self._segment_count > self.max_segment_count:
raise ParserLimitReached("Maximum segment count exceeded")
return self.segment_class(self)

def close(self, check_complete=True):
"""
Close this parser if not already closed.
@@ -510,7 +523,7 @@
"""

self.closed = True
self._current = None
self._segment = None
del self._buffer[:]

if check_complete and self._state is not _COMPLETE:
@@ -551,39 +564,34 @@ class MultipartSegment:
def __init__(self, parser: PushMultipartParser):
""" Private constructor, used by :class:`PushMultipartParser` """
self._parser = parser

if parser._fieldcount+1 > parser.max_segment_count:
raise ParserLimitReached("Maximum segment count exceeded")
parser._fieldcount += 1

self.headerlist = []
self.size = 0
self.complete = 0
self.complete = False

self.name = None
self.name = ""
self.filename = None
self.content_type = None
self.charset = None
self._maxlen = parser.max_segment_size
self._clen = -1
self._size_limit = parser.max_segment_size

def _add_headerline(self, line: bytearray):
assert line and self.name is None
parser = self._parser
def _on_headerline(self, line: bytearray):
""" Called for each raw header line in a segment. """

if line[0] in b" \t": # Multi-line header value
if not self.headerlist or parser.strict:
if line[0] in b" \t": # Continuation of last header line
if not self.headerlist or self._parser.strict:
raise StrictParserError("Unexpected segment header continuation")
prev = ": ".join(self.headerlist.pop())
line = prev.encode(parser.header_charset) + b" " + line.strip()
line = prev.encode(self._parser.header_charset) + b" " + line.strip()

if len(line) > parser.max_header_size:
if len(line) > self._parser.max_header_size:
raise ParserLimitReached("Maximum segment header length exceeded")
if len(self.headerlist) >= parser.max_header_count:

if len(self.headerlist) >= self._parser.max_header_count:
raise ParserLimitReached("Maximum segment header count exceeded")

try:
name, col, value = line.decode(parser.header_charset).partition(":")
name, col, value = line.decode(self._parser.header_charset).partition(":")
name = name.strip()
if not col or not name:
raise ParserError("Malformed segment header")
@@ -594,9 +602,10 @@ def _add_headerline(self, line: bytearray):

self.headerlist.append((name.title(), value.strip()))

def _close_headers(self):
assert self.name is None
def _on_header_complete(self):
""" Called after the last segment header. """

dtype = False
for h,v in self.headerlist:
if h == "Content-Disposition":
dtype, args = parse_options_header(v, unquote=content_disposition_unquote)
Expand All @@ -611,21 +620,23 @@ def _close_headers(self):
self.charset = args.get("charset")
elif h == "Content-Length" and v.isdecimal():
self._clen = int(v)
self._maxlen = min(self._clen, self._maxlen)

if self.name is None:
if not dtype:
raise ParserError("Missing Content-Disposition segment header")

def _update_size(self, bytecount: int):
assert self.name is not None and not self.complete
self.size += bytecount
if self._clen >= 0 and self.size > self._clen:
raise ParserError("Segment Content-Length exceeded")
if self.size > self._size_limit:
def _on_data(self, chunk: bytearray) -> bytearray:
""" Called for each chunk of segment data. Must return the chunk. """
self.size += len(chunk)
if self.size > self._maxlen:
if self.size > self._clen > -1:
raise ParserError("Segment Content-Length exceeded")
raise ParserLimitReached("Maximum segment size exceeded")
return chunk

def _mark_complete(self):
assert self.name is not None and not self.complete
if self._clen >= 0 and self.size != self._clen:
def _on_data_complete(self):
""" Called after the last chunk of segment data. """
if self._clen > -1 and self.size != self._clen:
raise ParserError("Segment size does not match Content-Length header")
self.complete = True

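# A minimal sketch of a custom segment class using the hooks above; illustrative
# only, and per the changelog the ``_on_*`` API is not stable yet. The names
# ``HashingSegment`` and ``hexdigest`` are assumptions, not part of this diff.
import hashlib

class HashingSegment(MultipartSegment):
    """Example subclass that computes a SHA-256 digest of the segment body."""

    def __init__(self, parser: PushMultipartParser):
        super().__init__(parser)
        self._sha256 = hashlib.sha256()
        self.hexdigest = None

    def _on_data(self, chunk: bytearray) -> bytearray:
        self._sha256.update(chunk)
        return super()._on_data(chunk)  # keep size accounting and limit checks

    def _on_data_complete(self):
        super()._on_data_complete()     # keep Content-Length verification
        self.hexdigest = self._sha256.hexdigest()

# Used via: PushMultipartParser(boundary, segment_class=HashingSegment)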