1 module eventcore.drivers.winapi.core;
2 
3 version (Windows):
4 
5 import eventcore.driver;
6 import eventcore.drivers.timer;
7 import eventcore.internal.consumablequeue;
8 import eventcore.internal.utils : mallocT, freeT, nogc_assert, print;
9 import eventcore.internal.win32;
10 import core.sync.mutex : Mutex;
11 import core.time : Duration;
12 import taggedalgebraic;
13 import std.stdint : intptr_t;
14 import std.typecons : Tuple, tuple;
15 
16 
/** Core of the WinAPI event driver.

	Runs the owner thread's wait loop: it waits on the registered event
	handles and the thread message queue using `MsgWaitForMultipleObjectsEx`,
	drains queued overlapped I/O completions, dispatches window messages,
	processes timers and executes callbacks posted from other threads via
	`runInOwnerThread`.
*/
final class WinAPIEventDriverCore : EventDriverCore {
@safe: /*@nogc:*/ nothrow:
	private alias ThreadCallbackEntry = Tuple!(ThreadCallbackGen, ThreadCallbackGenParams);

	private {
		bool m_exit;
		size_t m_waiterCount;
		DWORD m_tid; // ID of the thread that constructed the driver (the owner thread)
		LoopTimeoutTimerDriver m_timers;
		// Handles passed to MsgWaitForMultipleObjectsEx; the callback at the same
		// index (possibly null) is invoked when that handle is signaled.
		HANDLE[MAXIMUM_WAIT_OBJECTS] m_registeredEvents;
		void delegate() @safe nothrow[MAXIMUM_WAIT_OBJECTS] m_registeredEventCallbacks;
		DWORD m_registeredEventCount = 0;
		HANDLE m_fileCompletionEvent;
		uint m_validationCounter;
		ConsumableQueue!IOEvent m_ioEvents; // completions queued by overlappedIOHandler

		shared Mutex m_threadCallbackMutex; // guards m_threadCallbacks
		ConsumableQueue!ThreadCallbackEntry m_threadCallbacks;
	}

	package {
		HandleSlot[HANDLE] m_handles; // FIXME: use allocator based hash map
	}

	/// Creates the core for the calling thread, which becomes the owner thread.
	this(LoopTimeoutTimerDriver timers)
	@nogc {
		m_timers = timers;
		m_tid = () @trusted { return GetCurrentThreadId(); } ();
		// Auto-reset, initially non-signaled event; always part of the wait set.
		m_fileCompletionEvent = () @trusted { return CreateEventW(null, false, false, null); } ();
		registerEvent(m_fileCompletionEvent);
		m_ioEvents = mallocT!(ConsumableQueue!IOEvent);
		m_threadCallbackMutex = mallocT!(shared(Mutex));
		m_threadCallbacks = mallocT!(ConsumableQueue!ThreadCallbackEntry);
		m_threadCallbacks.reserve(1000);
	}

	/// Releases the resources acquired in the constructor.
	void dispose()
	@trusted {
		// The completion event handle is owned by this object; close it so it
		// does not leak when the driver is torn down.
		if (m_fileCompletionEvent) {
			CloseHandle(m_fileCompletionEvent);
			m_fileCompletionEvent = null;
		}

		try {
			freeT(m_threadCallbacks);
			freeT(m_threadCallbackMutex);
			freeT(m_ioEvents);
		} catch (Exception e) assert(false, e.msg);
	}

	/** Reports handles that are still registered.

		Prints a warning plus one line per handle that carries kind-specific
		state.

		Returns: `true` if any handle is still registered.
	*/
	package bool checkForLeakedHandles()
	@trusted {
		import core.thread : Thread;

		static string getThreadName()
		{
			string thname;
			try {
				// Thread.getThis() returns null for threads that are not
				// attached to the D runtime.
				if (auto th = Thread.getThis())
					thname = th.name;
			} catch (Exception e) assert(false, e.msg);
			return thname.length ? thname : "unknown";
		}

		if (!m_handles.length) return false;

		print("Warning (thread: %s): Leaked handles detected at driver shutdown", getThreadName());
		foreach (ks; m_handles.byKeyValue)
			if (!ks.value.specific.hasType!(typeof(null)))
				print("   FD %04X (%s)", ks.key, ks.value.specific.kind);
		return true;
	}

	/// Number of explicit waiters plus pending timers.
	override size_t waiterCount() { return m_waiterCount + m_timers.pendingCount; }

	package void addWaiter() @nogc { m_waiterCount++; }
	package void removeWaiter()
	@nogc {
		assert(m_waiterCount > 0, "Decrementing waiter count below zero.");
		m_waiterCount--;
	}

	/** Waits for and processes events until something happened or `timeout` elapsed.

		Returns: `exited` if `exit()` was requested, `outOfWaiters` if nothing
		is waited for anymore, `idle` if at least one event or timer was
		processed, and `timeout` otherwise.
	*/
	override ExitReason processEvents(Duration timeout = Duration.max)
	{
		import std.algorithm : min;
		import core.time : MonoTime, seconds;

		if (m_exit) {
			m_exit = false;
			return ExitReason.exited;
		}

		if (!waiterCount) return ExitReason.outOfWaiters;

		bool got_event;
		auto now = MonoTime.currTime;
		do {
			// Never sleep past the next timer deadline.
			auto nextto = min(m_timers.getNextTimeout(now), timeout);
			got_event |= doProcessEvents(nextto);
			auto prev_step = now;
			now = MonoTime.currTime;
			got_event |= m_timers.process(now);

			if (m_exit) {
				m_exit = false;
				return ExitReason.exited;
			} else if (got_event) break;
			// Duration.max means "wait forever" and is never decremented.
			if (timeout != Duration.max)
				timeout -= now - prev_step;
		} while (timeout > 0.seconds);

		if (!waiterCount) return ExitReason.outOfWaiters;
		if (got_event) return ExitReason.idle;
		return ExitReason.timeout;
	}

	/// Requests the event loop to exit and wakes the owner thread with WM_QUIT.
	override void exit()
	@trusted {
		m_exit = true;
		PostThreadMessageW(m_tid, WM_QUIT, 0, 0);
	}

	override void clearExitFlag()
	{
		m_exit = false;
	}

	/** Queues `del(params)` for execution on the owner thread.

		May be called from any thread; the owner thread is woken with a
		WM_APP thread message.
	*/
	override void runInOwnerThread(ThreadCallbackGen del,
		ref ThreadCallbackGenParams params)
	shared {
		import core.atomic : atomicLoad;

		auto m = atomicLoad(m_threadCallbackMutex);
		// NOTE: This case must be handled gracefully to avoid hazardous
		//       race-conditions upon unexpected thread termination. The mutex
		//       and the map will stay valid even after the driver has been
		//       disposed, so no further synchronization is required.
		// NOTE(review): dispose() frees the mutex and the queue without
		//       resetting this field to null, so this check does not guard
		//       against calls made after dispose() - confirm callers never
		//       do that.
		if (!m) return;

		try {
			synchronized (m)
				() @trusted { return (cast()this).m_threadCallbacks; } ()
					.put(ThreadCallbackEntry(del, params));
		} catch (Exception e) assert(false, e.msg);

		() @trusted { PostThreadMessageW(m_tid, WM_APP, 0, 0); } ();
	}

	alias runInOwnerThread = EventDriverCore.runInOwnerThread;

	/** Returns the user data area of `handle`'s slot.

		On the first request the area is initialized with `initialize` and
		`destroy` is remembered as its destructor. Subsequent requests must
		pass the same destructor.
	*/
	package void* rawUserDataImpl(HANDLE handle, size_t size, DataInitializer initialize, DataInitializer destroy)
	@system {
		HandleSlot* fds = &m_handles[handle];
		assert(fds.userDataDestructor is null || fds.userDataDestructor is destroy,
			"Requesting user data with differing type (destructor).");
		assert(size <= HandleSlot.userData.length, "Requested user data is too large.");
		// Also enforced in builds where the assert above is compiled out.
		if (size > HandleSlot.userData.length) assert(false);
		if (!fds.userDataDestructor) {
			initialize(fds.userData.ptr);
			fds.userDataDestructor = destroy;
		}
		return fds.userData.ptr;
	}

	protected override void* rawUserData(StreamSocketFD descriptor, size_t size, DataInitializer initialize, DataInitializer destroy) @system
	{
		assert(false, "TODO!");
	}

	protected override void* rawUserData(DatagramSocketFD descriptor, size_t size, DataInitializer initialize, DataInitializer destroy) @system
	{
		assert(false, "TODO!");
	}

	/** Performs one alertable wait of at most `max_wait` and handles its results.

		Runs queued thread callbacks, waits on the registered handles and the
		message queue, processes queued I/O completions, invokes the callback
		of a signaled handle, dispatches pending window messages and runs
		thread callbacks again.

		Returns: `true` if any event was handled.
	*/
	private bool doProcessEvents(Duration max_wait)
	{
		import std.algorithm.comparison : min, max;

		executeThreadCallbacks();

		bool got_event;

		// INFINITE equals DWORD.max, so a finite wait must be clamped below it;
		// otherwise a very long (but finite) timeout turns into an endless wait.
		DWORD timeout_msecs = max_wait == Duration.max
			? INFINITE
			: cast(DWORD)min(max(max_wait.total!"msecs", 0), DWORD.max - 1);
		auto ret = () @trusted { return MsgWaitForMultipleObjectsEx(m_registeredEventCount, m_registeredEvents.ptr,
			timeout_msecs, QS_ALLEVENTS, MWMO_ALERTABLE|MWMO_INPUTAVAILABLE); } ();

		// Completion routines only enqueue events; process them here.
		while (!m_ioEvents.empty) {
			auto evt = m_ioEvents.consumeOne();
			evt.process(evt.error, evt.bytesTransferred, evt.overlapped);
		}

		if (ret == WAIT_IO_COMPLETION) got_event = true;
		else if (ret >= WAIT_OBJECT_0 && ret < WAIT_OBJECT_0 + m_registeredEventCount) {
			if (auto cb = m_registeredEventCallbacks[ret - WAIT_OBJECT_0]) {
				cb();
				got_event = true;
			}
		}

		/*if (ret == WAIT_OBJECT_0) {
			got_event = true;
			Win32TCPConnection[] to_remove;
			foreach (fw; m_fileWriters.byKey)
				if (fw.testFileWritten())
					to_remove ~= fw;
			foreach (fw; to_remove)
			m_fileWriters.remove(fw);
		}*/

		MSG msg;
		//uint cnt = 0;
		while (() @trusted { return PeekMessageW(&msg, null, 0, 0, PM_REMOVE); } ()) {
			// WM_QUIT posted by exit() ends message dispatching for this round.
			if (msg.message == WM_QUIT && m_exit)
				break;

			() @trusted {
				TranslateMessage(&msg);
				DispatchMessageW(&msg);
			} ();

			got_event = true;

			// process timers every now and then so that they don't get stuck
			//if (++cnt % 10 == 0) processTimers();
		}

		executeThreadCallbacks();

		return got_event;
	}


	/** Adds `event` to the set of handles waited on by the event loop.

		`callback`, if non-null, is invoked when the handle is signaled.
	*/
	package void registerEvent(HANDLE event, void delegate() @safe nothrow callback = null)
	@nogc {
		// MsgWaitForMultipleObjectsEx accepts at most MAXIMUM_WAIT_OBJECTS - 1
		// handles because the message queue occupies one wait slot.
		assert(m_registeredEventCount < MAXIMUM_WAIT_OBJECTS - 1, "Too many registered events.");
		m_registeredEvents[m_registeredEventCount] = event;
		m_registeredEventCallbacks[m_registeredEventCount] = callback;
		m_registeredEventCount++;
	}

	/// Creates a slot for `h` with a reference count of 1 and returns its `SlotType` part.
	package SlotType* setupSlot(SlotType)(HANDLE h)
	{
		assert(h !in m_handles, "Handle already in use.");
		HandleSlot s;
		s.refCount = 1;
		s.validationCounter = ++m_validationCounter;
		s.specific = SlotType.init;
		m_handles[h] = s;
		return () @trusted { return &m_handles[h].specific.get!SlotType(); } ();
	}

	/// Removes the slot of `h`; the handle must currently be registered.
	package void freeSlot(HANDLE h)
	{
		nogc_assert((h in m_handles) !is null, "Handle not in use - cannot free.");
		m_handles.remove(h);
	}

	/// Drops queued I/O completions that belong to any of the given operations.
	package void discardEvents(scope OVERLAPPED_CORE*[] overlapped...)
	@nogc {
		import std.algorithm.searching : canFind;
		m_ioEvents.filterPending!(evt => !overlapped.canFind(evt.overlapped));
	}

	/// Runs all queued cross-thread callbacks; the lock is not held while a callback runs.
	private void executeThreadCallbacks()
	{
		while (true) {
			ThreadCallbackEntry del;
			try {
				synchronized (m_threadCallbackMutex) {
					if (m_threadCallbacks.empty) break;
					del = m_threadCallbacks.consumeOne;
				}
			} catch (Exception e) assert(false, e.msg);
			del[0](del[1]);
		}
	}
}
290 
/// Current system time in hnsecs as reported by `Clock.currStdTime`.
/// Any failure while reading the clock is treated as fatal.
private long currStdTime()
@safe nothrow {
	import std.datetime : Clock;

	scope (failure) assert(false);
	immutable stdTime = Clock.currStdTime;
	return stdTime;
}
297 
/// Per-handle bookkeeping stored in `WinAPIEventDriverCore.m_handles`.
private struct HandleSlot {
	/// Kind-specific state; `none` marks a slot without such state.
	static union SpecificTypes {
		typeof(null) none;
		FileSlot files;
		WatcherSlot watcher;
	}
	int refCount;           // number of outstanding references
	uint validationCounter; // value assigned when the slot was set up
	TaggedAlgebraic!SpecificTypes specific;

	DataInitializer userDataDestructor; // set once user data was initialized
	ubyte[16*size_t.sizeof] userData;   // inline storage for per-handle user data

	@safe nothrow:

	/// Access the file state; the slot must currently hold a `FileSlot`.
	@property ref FileSlot file() { return specific.get!FileSlot(); }
	/// Access the watcher state; the slot must currently hold a `WatcherSlot`.
	@property ref WatcherSlot watcher() { return specific.get!WatcherSlot(); }

	/// Adds a reference to a slot that is still referenced.
	void addRef()
	{
		assert(refCount > 0);
		refCount += 1;
	}

	/** Drops one reference.

		Returns: `true` while references remain; `false` once the last one
		was released, in which case `on_free` has been called.
	*/
	bool releaseRef(scope void delegate() @safe nothrow on_free)
	{
		nogc_assert(refCount > 0, "Releasing unreferenced slot.");
		--refCount;
		if (refCount != 0) return true;
		on_free();
		return false;
	}
}
332 
/// Per-handle state of an open file: one pending read and one pending write.
package struct FileSlot {
	/// State of one transfer direction; `RO` selects a read-only buffer (writes).
	static struct Direction(bool RO) {
		OVERLAPPED_CORE overlapped;
		FileIOCallback callback;
		ulong offset;
		size_t bytesTransferred;
		IOMode mode;
		static if (RO) {
			const(ubyte)[] buffer;
		} else {
			ubyte[] buffer;
		}

		/** Consumes the stored callback and invokes it with the given result.

			The callback is cleared before the call. If the handle stored in
			`overlapped.hEvent` is no longer registered with the driver, the
			callback is dropped without being called.
		*/
		void invokeCallback(IOStatus status, size_t bytes_transferred)
		@safe nothrow {
			auto handler = this.callback;
			this.callback = null;
			assert(handler !is null);

			auto slot = overlapped.hEvent in overlapped.driver.m_handles;
			if (slot is null) return;

			auto counter = slot.validationCounter;
			handler(FileFD(cast(size_t)overlapped.hEvent, counter), status, bytes_transferred);
		}
	}
	Direction!false read;
	Direction!true write;
	FileCloseCallback closeCallback;
}
359 
/// Kind-specific state of a handle used as a directory watcher
/// (stored in `HandleSlot.specific`).
package struct WatcherSlot {
	ubyte[] buffer; // NOTE(review): presumably receives change records; not used in this module - confirm
	OVERLAPPED_CORE overlapped;
	string directory;
	bool recursive;
	FileChangesCallback callback;
}
367 
/// `OVERLAPPED` extended with a reference to the owning driver core.
package struct OVERLAPPED_CORE {
	// Must stay the first member: overlappedIOHandler casts the `OVERLAPPED*`
	// it receives back to `OVERLAPPED_CORE*`.
	OVERLAPPED overlapped;
	alias overlapped this;
	// Driver whose m_ioEvents queue receives this operation's completion.
	WinAPIEventDriverCore driver;
}
373 
/// A completed overlapped I/O operation queued for deferred processing.
/// Entries are created by overlappedIOHandler and consumed by the driver core.
package struct IOEvent {
	// Handler called as process(error, bytesTransferred, overlapped) when the
	// event is consumed.
	void function(DWORD err, DWORD bts, OVERLAPPED_CORE*) @safe nothrow process;
	DWORD error;            // error code passed to the completion routine
	DWORD bytesTransferred; // byte count passed to the completion routine
	OVERLAPPED_CORE* overlapped; // the operation that completed
}
380 
/** I/O completion routine that defers processing to the event loop.

	Does not call `process` directly. Instead it enqueues an `IOEvent` on the
	owning driver's queue, where the event loop picks it up later.
*/
package extern(System) @system nothrow
void overlappedIOHandler(alias process, EXTRA...)(DWORD error, DWORD bytes_transferred, OVERLAPPED* _overlapped, EXTRA extra)
{
	// OVERLAPPED is the first member of OVERLAPPED_CORE, so the pointer can be
	// converted back to the enclosing structure.
	OVERLAPPED_CORE* ovl = cast(OVERLAPPED_CORE*)_overlapped;
	auto queue = ovl.driver.m_ioEvents;
	auto evt = IOEvent(&process, error, bytes_transferred, ovl);
	queue.put(evt);
}