Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions app/api/v2/handlers/payload_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,31 @@
PayloadDeleteRequestSchema


ALLOWED_EXTENSIONS = [
'.ps1', '.sh', '.py', '.exe', '.elf', '.bat', '.vbs', '.js', '.go', '.c',
'.zip', '.tar', '.gz', '.dll', '.bin', '.yaml', '.yml', '.txt', '.json',
]

DANGEROUS_MAGIC_BYTES = [
b'<?php', b'<%@', b'<%!', b'<%@ Page',
]


def _validate_payload_file(filename, file_content_start):
"""Validate payload filename extension and magic bytes.
Returns (is_valid, error_message).
"""
if '\x00' in filename:
return False, 'Null byte detected in filename'
ext = os.path.splitext(filename)[1].lower()
if ext and ext not in ALLOWED_EXTENSIONS:
return False, f'File extension not allowed: {ext}'
for magic in DANGEROUS_MAGIC_BYTES:
if file_content_start.startswith(magic):
Comment on lines +36 to +37
return False, f'Dangerous file signature detected'
Comment on lines +36 to +38
return True, ''


class PayloadApi(BaseApi):
def __init__(self, services):
super().__init__(auth_svc=services['auth_svc'])
Expand Down Expand Up @@ -70,6 +95,13 @@ async def post_payloads(self, request: web.Request):
# accessing the file using the prefilled request["form"] dictionary.
file_field: web.FileField = request["form"]["file"]

# Validate filename and magic bytes
first_bytes = file_field.file.read(16)
file_field.file.seek(0)
is_valid, error_msg = _validate_payload_file(file_field.filename, first_bytes)
if not is_valid:
raise web.HTTPBadRequest(text=error_msg)

# Sanitize the file name to prevent directory traversal
sanitized_filename = self.sanitize_filename(file_field.filename)

Expand Down
43 changes: 43 additions & 0 deletions tests/security/test_payload_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import unittest
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))


class TestPayloadValidation(unittest.TestCase):
def test_valid_extension(self):
from app.api.v2.handlers.payload_api import _validate_payload_file
ok, _ = _validate_payload_file('test.ps1', b'\x00\x00\x00\x00')
self.assertTrue(ok)

def test_invalid_extension(self):
from app.api.v2.handlers.payload_api import _validate_payload_file
ok, msg = _validate_payload_file('test.php', b'normal content')
self.assertFalse(ok)
self.assertIn('extension', msg.lower())

def test_dangerous_magic_bytes_php(self):
from app.api.v2.handlers.payload_api import _validate_payload_file
ok, msg = _validate_payload_file('test.txt', b'<?php echo "hi";')
self.assertFalse(ok)
self.assertIn('Dangerous', msg)

def test_dangerous_magic_bytes_jsp(self):
from app.api.v2.handlers.payload_api import _validate_payload_file
ok, msg = _validate_payload_file('test.txt', b'<%@ page import')
self.assertFalse(ok)

def test_null_byte_in_filename(self):
from app.api.v2.handlers.payload_api import _validate_payload_file
ok, msg = _validate_payload_file('test\x00.txt', b'safe')
self.assertFalse(ok)
self.assertIn('Null byte', msg)

def test_no_extension_is_allowed(self):
from app.api.v2.handlers.payload_api import _validate_payload_file
ok, _ = _validate_payload_file('myagent', b'\x7fELF')
self.assertTrue(ok)


if __name__ == '__main__':
unittest.main()
Loading