forked from mhammond/pywin32
-
Notifications
You must be signed in to change notification settings - Fork 0
/
threaded_extension.py
188 lines (165 loc) · 7.15 KB
/
threaded_extension.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
"""An ISAPI extension base class implemented using a thread-pool."""
# $Id$
import sys
import time
from isapi import isapicon, ExtensionError
import isapi.simple
from win32file import (
GetQueuedCompletionStatus,
CreateIoCompletionPort,
PostQueuedCompletionStatus,
CloseHandle,
)
from win32security import SetThreadToken
from win32event import INFINITE
from pywintypes import OVERLAPPED
import threading
import traceback
# Completion keys for packets posted to the extension's IO completion port:
#   ISAPI_REQUEST  - the packet's OVERLAPPED object carries an ISAPI control block
#   ISAPI_SHUTDOWN - posted with no OVERLAPPED; tells a WorkerThread to exit
ISAPI_REQUEST, ISAPI_SHUTDOWN = 1, 2
class WorkerThread(threading.Thread):
    """A pool thread that services packets from an IO completion port.

    Each packet dequeued from ``io_req_port`` is routed, by its completion
    key, to the matching callable in ``extension.dispatch_map``.  A packet
    whose key is ``ISAPI_SHUTDOWN`` and which carries no OVERLAPPED object
    makes the thread exit.
    """

    def __init__(self, extension, io_req_port):
        """Create (but do not start) a worker.

        :param extension: object providing ``dispatch_map`` (completion key
            -> callable) and ``Dispatch`` (used by ``call_handler``).
        :param io_req_port: completion port handle this thread blocks on.
        """
        self.running = False
        self.io_req_port = io_req_port
        self.extension = extension
        threading.Thread.__init__(self)
        # The owning extension waits (15 seconds by default) for a thread to
        # terminate, but if it fails to, we don't want the process to hang at
        # exit waiting for it...
        # Thread.setDaemon() is deprecated since Python 3.10; assigning the
        # ``daemon`` attribute is the supported equivalent.
        self.daemon = True

    def run(self):
        """Dequeue and dispatch completion packets until told to stop.

        The loop exits on a shutdown packet, or when ``running`` has been
        cleared.  Because ``GetQueuedCompletionStatus`` blocks with
        ``INFINITE``, a cleared ``running`` flag is only noticed after the
        next packet has been handled.

        :raises RuntimeError: if a packet's key has no entry in the
            extension's ``dispatch_map``; this terminates the thread.
        """
        self.running = True
        while self.running:
            err_code, num_bytes, key, overlapped = GetQueuedCompletionStatus(
                self.io_req_port, INFINITE
            )
            if key == ISAPI_SHUTDOWN and overlapped is None:
                break
            # Let the parent extension handle the command.
            dispatcher = self.extension.dispatch_map.get(key)
            if dispatcher is None:
                raise RuntimeError("Bad request '%s'" % (key,))
            dispatcher(err_code, num_bytes, key, overlapped)

    def call_handler(self, cblock):
        """Pass ``cblock`` straight to the extension's ``Dispatch`` method.

        Not referenced elsewhere in this module; kept for compatibility.
        """
        self.extension.Dispatch(cblock)
# A generic thread-pool based extension, using IO Completion Ports.
# Sub-classes can override one method (Dispatch) to implement a simple
# extension, or may leverage the completion port to queue their own requests
# and implement a fully asynchronous extension.
class ThreadPoolExtension(isapi.simple.SimpleExtension):
    "Base class for an ISAPI extension based around a thread-pool"

    # Number of WorkerThreads started by GetExtensionVersion.
    max_workers = 20
    # Milliseconds TerminateExtension waits for workers to exit.
    worker_shutdown_wait = 15000  # 15 seconds for workers to quit...

    def __init__(self):
        self.workers = []
        # extensible dispatch map, for sub-classes that need to post their
        # own requests to the completion port.
        # Each of these functions is called with the result of
        # GetQueuedCompletionStatus for our port.
        self.dispatch_map = {
            ISAPI_REQUEST: self.DispatchConnection,
        }

    def GetExtensionVersion(self, vi):
        """Initialise the extension: create the completion port and start
        ``max_workers`` worker threads that service it."""
        isapi.simple.SimpleExtension.GetExtensionVersion(self, vi)
        # As per Q192800, the CompletionPort should be created with the number
        # of processors, even if the number of worker threads is much larger.
        # Passing 0 means the system picks the number.
        self.io_req_port = CreateIoCompletionPort(-1, None, 0, 0)
        # start up the workers
        self.workers = []
        for _ in range(self.max_workers):
            worker = WorkerThread(self, self.io_req_port)
            worker.start()
            self.workers.append(worker)

    def HttpExtensionProc(self, control_block):
        """Queue the request for a pool thread and return immediately.

        The control block rides on the OVERLAPPED object's ``object``
        attribute and is picked up by ``DispatchConnection``.

        :returns: ``isapicon.HSE_STATUS_PENDING``; the request is completed
            later on a worker thread.
        """
        overlapped = OVERLAPPED()
        overlapped.object = control_block
        PostQueuedCompletionStatus(self.io_req_port, 0, ISAPI_REQUEST, overlapped)
        return isapicon.HSE_STATUS_PENDING

    def TerminateExtension(self, status):
        """Stop the worker pool and close the completion port.

        Clears every worker's ``running`` flag, posts one shutdown packet per
        worker, then polls for up to ``worker_shutdown_wait`` milliseconds for
        the workers to exit.  Workers still alive at the deadline are
        abandoned.  ``status`` is not used by this implementation.
        """
        for worker in self.workers:
            worker.running = False
        # One shutdown packet per worker, so each blocked thread wakes once.
        for _ in self.workers:
            PostQueuedCompletionStatus(self.io_req_port, 0, ISAPI_SHUTDOWN, None)
        # wait for them to terminate - pity we aren't using 'native' threads
        # as then we could do a smart wait - but now we need to poll....
        # A monotonic clock keeps the deadline immune to wall-clock changes.
        end_time = time.monotonic() + self.worker_shutdown_wait / 1000
        alive = self.workers
        while alive:
            if time.monotonic() > end_time:
                # xxx - might be nice to log something here.
                break
            time.sleep(0.2)
            alive = [w for w in alive if w.is_alive()]
        self.dispatch_map = {}  # break circles
        CloseHandle(self.io_req_port)

    # This is the one operation the base class supports - a simple
    # Connection request. We setup the thread-token, and dispatch to the
    # sub-class's 'Dispatch' method.
    def DispatchConnection(self, errCode, bytes, key, overlapped):
        """Handle an ``ISAPI_REQUEST`` packet on a worker thread.

        Impersonates the request's user, calls ``Dispatch``, routes any
        exception to ``HandleDispatchError``, and always reverts the thread
        token afterwards.
        """
        control_block = overlapped.object
        try:
            # setup the correct user for this request.  This is inside the
            # try so a failure here still reaches HandleDispatchError (which
            # ends the session) instead of killing the worker thread and
            # leaving the pending request open.
            hRequestToken = control_block.GetImpersonationToken()
            SetThreadToken(None, hRequestToken)
            self.Dispatch(control_block)
        except BaseException:
            # Deliberately broad: any failure must still produce a response
            # and end the session.
            self.HandleDispatchError(control_block)
        finally:
            # reset the security context
            SetThreadToken(None, None)

    def Dispatch(self, ecb):
        """Overridden by the sub-class to handle connection requests.

        This class creates a thread-pool using a Windows completion port,
        and dispatches requests via this port.  Sub-classes can generally
        implement each connection request using blocking reads and writes, and
        the thread-pool will still provide decent response to the end user.

        The sub-class can set a max_workers attribute (default is 20).  Note
        that this generally does *not* mean 20 threads will all be concurrently
        running, via the magic of Windows completion ports.

        There is no default implementation - sub-classes must implement this.
        """
        raise NotImplementedError("sub-classes should override Dispatch")

    def HandleDispatchError(self, ecb):
        """Handles errors in the Dispatch method.

        When a Dispatch method call fails, this method is called to handle
        the exception.  The default implementation formats the traceback
        in the browser.  The session is always ended via ``DoneWithSession``.
        """
        # NOTE(review): HttpStatusCode is assigned an HSE_STATUS_* constant
        # rather than an HTTP status number, while the header below sends
        # "200 OK" - confirm this is intended.
        ecb.HttpStatusCode = isapicon.HSE_STATUS_ERROR
        # control_block.LogData = "we failed!"
        exc_typ, exc_val, exc_tb = sys.exc_info()
        try:
            try:
                # cgi.escape was removed in Python 3.8 (and the cgi module in
                # 3.13); html.escape(..., quote=False) escapes exactly the
                # same characters (&, <, >).
                import html

                ecb.SendResponseHeaders(
                    "200 OK", "Content-type: text/html\r\n\r\n", False
                )
                print(file=ecb)
                print("<H3>Traceback (most recent call last):</H3>", file=ecb)
                tb_lines = traceback.format_tb(
                    exc_tb
                ) + traceback.format_exception_only(exc_typ, exc_val)
                # The final line (the exception itself) is rendered in bold.
                print(
                    "<PRE>%s<B>%s</B></PRE>"
                    % (
                        html.escape("".join(tb_lines[:-1]), quote=False),
                        html.escape(tb_lines[-1], quote=False),
                    ),
                    file=ecb,
                )
            except ExtensionError:
                # The client disconnected without reading the error body -
                # it's probably not a real browser at the other end, ignore it.
                pass
            except BaseException:
                # Best effort: report both failures to stdout rather than
                # raising out of the error handler.
                print("FAILED to render the error message!")
                traceback.print_exc()
                print("ORIGINAL extension error:")
                traceback.print_exception(exc_typ, exc_val, exc_tb)
        finally:
            # holding tracebacks in a local of a frame that may itself be
            # part of a traceback used to be evil and cause leaks!
            exc_tb = None
            ecb.DoneWithSession()