Skip to content

Commit

Permalink
Limit origin check to when host is loopback.
Browse files Browse the repository at this point in the history
This should still prevent the exploit without breaking things for people
who use reverse proxies.
  • Loading branch information
comfyanonymous committed Sep 11, 2024
1 parent 81778a7 commit 36c83cd
Showing 1 changed file with 33 additions and 1 deletion.
34 changes: 33 additions & 1 deletion server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import glob
import struct
import ssl
import socket
import ipaddress
from PIL import Image, ImageOps
from PIL.PngImagePlugin import PngInfo
from io import BytesIO
Expand Down Expand Up @@ -80,6 +82,32 @@ async def cors_middleware(request: web.Request, handler):

return cors_middleware

def is_loopback(host):
if host is None:
return False
try:
if ipaddress.ip_address(host).is_loopback:
return True
else:
return False
except:
pass

loopback = False
for family in (socket.AF_INET, socket.AF_INET6):
try:
r = socket.getaddrinfo(host, None, family, socket.SOCK_STREAM)
for family, _, _, _, sockaddr in r:
if not ipaddress.ip_address(sockaddr[0]).is_loopback:
return loopback
else:
loopback = True
except socket.gaierror:
pass

return loopback


def create_origin_only_middleware():
@web.middleware
async def origin_only_middleware(request: web.Request, handler):
Expand All @@ -93,12 +121,16 @@ async def origin_only_middleware(request: web.Request, handler):
parsed = urllib.parse.urlparse(origin)
origin_domain = parsed.netloc.lower()
host_domain_parsed = urllib.parse.urlsplit('//' + host_domain)

#limit the check to when the host domain is localhost, this makes it slightly less safe but should still prevent the exploit
loopback = is_loopback(host_domain_parsed.hostname)

if parsed.port is None: #if origin doesn't have a port strip it from the host to handle weird browsers, same for host
host_domain = host_domain_parsed.hostname
if host_domain_parsed.port is None:
origin_domain = parsed.hostname

if host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
if host_domain != origin_domain:
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
return web.Response(status=403)
Expand Down

0 comments on commit 36c83cd

Please sign in to comment.