diff --git a/tst/memfs/memfs.cpp b/tst/memfs/memfs.cpp index 697893f8..06ebf2bb 100644 --- a/tst/memfs/memfs.cpp +++ b/tst/memfs/memfs.cpp @@ -192,7 +192,7 @@ typedef struct _MEMFS_FILE_NODE SIZE_T ReparseDataSize; PVOID ReparseData; #endif - ULONG RefCount; + volatile LONG RefCount; #if defined(MEMFS_NAMED_STREAMS) struct _MEMFS_FILE_NODE *MainFileNode; #endif @@ -260,13 +260,13 @@ VOID MemfsFileNodeDelete(MEMFS_FILE_NODE *FileNode) static inline VOID MemfsFileNodeReference(MEMFS_FILE_NODE *FileNode) { - FileNode->RefCount++; + InterlockedIncrement(&FileNode->RefCount); } static inline VOID MemfsFileNodeDereference(MEMFS_FILE_NODE *FileNode) { - if (0 == --FileNode->RefCount) + if (0 == InterlockedDecrement(&FileNode->RefCount)) MemfsFileNodeDelete(FileNode); } @@ -850,14 +850,18 @@ static NTSTATUS Overwrite(FSP_FILE_SYSTEM *FileSystem, NTSTATUS Result; #if defined(MEMFS_NAMED_STREAMS) - MEMFS_FILE_NODE_MAP_ENUM_CONTEXT Context = { FALSE }; + MEMFS_FILE_NODE_MAP_ENUM_CONTEXT Context = { TRUE }; ULONG Index; MemfsFileNodeMapEnumerateNamedStreams(Memfs->FileNodeMap, FileNode, MemfsFileNodeMapEnumerateFn, &Context); for (Index = 0; Context.Count > Index; Index++) - if (1 >= Context.FileNodes[Index]->RefCount) + { + LONG RefCount = Context.FileNodes[Index]->RefCount; + MemoryBarrier(); + if (2 >= RefCount) MemfsFileNodeMapRemove(Memfs->FileNodeMap, Context.FileNodes[Index]); + } MemfsFileNodeMapEnumerateFree(&Context); #endif