diff --git a/tst/memfs/memfs.cpp b/tst/memfs/memfs.cpp index 9d7bf8a1..63ba2b82 100644 --- a/tst/memfs/memfs.cpp +++ b/tst/memfs/memfs.cpp @@ -518,11 +518,18 @@ NTSTATUS Overwrite(FSP_FILE_SYSTEM *FileSystem, return STATUS_SUCCESS; } +typedef struct _MEMFS_CLEANUP_CONTEXT +{ + MEMFS_FILE_NODE **FileNodes; + ULONG Count; +} MEMFS_CLEANUP_CONTEXT; + static BOOLEAN CleanupEnumFn(MEMFS_FILE_NODE *FileNode, PVOID Context0) { - MEMFS *Memfs = (MEMFS *)Context0; + MEMFS_CLEANUP_CONTEXT *Context = (MEMFS_CLEANUP_CONTEXT *)Context0; + + Context->FileNodes[Context->Count++] = FileNode; - MemfsFileNodeMapRemove(Memfs->FileNodeMap, FileNode); return TRUE; } @@ -538,7 +545,18 @@ static VOID Cleanup(FSP_FILE_SYSTEM *FileSystem, if (Delete && !MemfsFileNodeMapHasChild(Memfs->FileNodeMap, FileNode)) { - MemfsFileNodeMapEnumerateNamedStreams(Memfs->FileNodeMap, FileNode, CleanupEnumFn, Memfs); + MEMFS_CLEANUP_CONTEXT Context = { 0 }; + ULONG Index; + + Context.FileNodes = (MEMFS_FILE_NODE **)malloc(Memfs->MaxFileNodes * sizeof Context.FileNodes[0]); + if (0 != Context.FileNodes) + { + MemfsFileNodeMapEnumerateNamedStreams(Memfs->FileNodeMap, FileNode, CleanupEnumFn, &Context); + for (Index = 0; Context.Count > Index; Index++) + MemfsFileNodeMapRemove(Memfs->FileNodeMap, Context.FileNodes[Index]); + free(Context.FileNodes); + } + MemfsFileNodeMapRemove(Memfs->FileNodeMap, FileNode); } }