diff --git a/tst/memfs/memfs.cpp b/tst/memfs/memfs.cpp index 3c5373d0..7ae5adfc 100644 --- a/tst/memfs/memfs.cpp +++ b/tst/memfs/memfs.cpp @@ -237,6 +237,22 @@ BOOLEAN MemfsFileNodeMapEnumerateChildren(MEMFS_FILE_NODE_MAP *FileNodeMap, MEMF return TRUE; } +static inline +BOOLEAN MemfsFileNodeMapEnumerateDescendants(MEMFS_FILE_NODE_MAP *FileNodeMap, MEMFS_FILE_NODE *FileNode, + BOOLEAN (*EnumFn)(MEMFS_FILE_NODE *, PVOID), PVOID Context) +{ + WCHAR Root[2] = L"\\"; + MEMFS_FILE_NODE_MAP::iterator iter = FileNodeMap->lower_bound(FileNode->FileName); + for (; FileNodeMap->end() != iter; ++iter) + { + if (!MemfsFileNameHasPrefix(iter->second->FileName, FileNode->FileName)) + break; + if (!EnumFn(iter->second, Context)) + return FALSE; + } + return TRUE; +} + static NTSTATUS SetFileSize(FSP_FILE_SYSTEM *FileSystem, FSP_FSCTL_TRANSACT_REQ *Request, PVOID FileNode0, UINT64 FileSize, @@ -630,6 +646,22 @@ static NTSTATUS CanDelete(FSP_FILE_SYSTEM *FileSystem, return STATUS_SUCCESS; } +typedef struct _MEMFS_RENAME_CONTEXT +{ + MEMFS_FILE_NODE **FileNodes; + ULONG Count; +} MEMFS_RENAME_CONTEXT; + +static BOOLEAN RenameEnumFn(MEMFS_FILE_NODE *FileNode, PVOID Context0) +{ + MEMFS_RENAME_CONTEXT *Context = (MEMFS_RENAME_CONTEXT *)Context0; + + Context->FileNodes[Context->Count++] = FileNode; + FileNode->RefCount++; + + return TRUE; +} + static NTSTATUS Rename(FSP_FILE_SYSTEM *FileSystem, FSP_FSCTL_TRANSACT_REQ *Request, PVOID FileNode0, @@ -637,40 +669,88 @@ static NTSTATUS Rename(FSP_FILE_SYSTEM *FileSystem, { MEMFS *Memfs = (MEMFS *)FileSystem->UserContext; MEMFS_FILE_NODE *FileNode = (MEMFS_FILE_NODE *)FileNode0; - MEMFS_FILE_NODE *NewFileNode; + MEMFS_FILE_NODE *NewFileNode, *DescendantFileNode; + MEMFS_RENAME_CONTEXT Context = { 0 }; + ULONG Index, FileNameLen, NewFileNameLen; BOOLEAN Inserted; NTSTATUS Result; assert(0 == FileName || 0 == wcscmp(FileNode->FileName, FileName)); - if (MAX_PATH <= wcslen(NewFileName)) - return STATUS_OBJECT_NAME_INVALID; - NewFileNode = MemfsFileNodeMapGet(Memfs->FileNodeMap, NewFileName); if (0 != NewFileNode) { if (!ReplaceIfExists) - return STATUS_OBJECT_NAME_COLLISION; + { + Result = STATUS_OBJECT_NAME_COLLISION; + goto exit; + } if (NewFileNode->FileInfo.FileAttributes & FILE_ATTRIBUTE_DIRECTORY) - return STATUS_ACCESS_DENIED; + { + Result = STATUS_ACCESS_DENIED; + goto exit; + } + } + Context.FileNodes = (MEMFS_FILE_NODE **)malloc(Memfs->MaxFileNodes * sizeof Context.FileNodes[0]); + if (0 == Context.FileNodes) + { + Result = STATUS_INSUFFICIENT_RESOURCES; + goto exit; + } + + MemfsFileNodeMapEnumerateDescendants(Memfs->FileNodeMap, FileNode, RenameEnumFn, &Context); + + FileNameLen = (ULONG)wcslen(FileNode->FileName); + NewFileNameLen = (ULONG)wcslen(NewFileName); + for (Index = 0; Context.Count > Index; Index++) + { + DescendantFileNode = Context.FileNodes[Index]; + assert(MemfsFileNameHasPrefix(DescendantFileNode->FileName, FileNode->FileName)); + if (MAX_PATH <= wcslen(DescendantFileNode->FileName) - FileNameLen + NewFileNameLen) + { + Result = STATUS_OBJECT_NAME_INVALID; + goto exit; + } + } + + if (0 != NewFileNode) + { NewFileNode->RefCount++; MemfsFileNodeMapRemove(Memfs->FileNodeMap, NewFileNode); if (0 == --NewFileNode->RefCount) MemfsFileNodeDelete(NewFileNode); } - MemfsFileNodeMapRemove(Memfs->FileNodeMap, FileNode); - wcscpy_s(FileNode->FileName, sizeof FileNode->FileName / sizeof(WCHAR), NewFileName); - Result = MemfsFileNodeMapInsert(Memfs->FileNodeMap, FileNode, &Inserted); - if (!NT_SUCCESS(Result)) + for (Index = 0; Context.Count > Index; Index++) { - FspDebugLog(__FUNCTION__ ": cannot insert into FileNodeMap; aborting\n"); - abort(); + DescendantFileNode = Context.FileNodes[Index]; + MemfsFileNodeMapRemove(Memfs->FileNodeMap, DescendantFileNode); + memmove(DescendantFileNode->FileName + NewFileNameLen, + DescendantFileNode->FileName + FileNameLen, + (wcslen(DescendantFileNode->FileName) + 1 - FileNameLen) * sizeof(WCHAR)); + memcpy(DescendantFileNode->FileName, NewFileName, NewFileNameLen * sizeof(WCHAR)); + Result = MemfsFileNodeMapInsert(Memfs->FileNodeMap, DescendantFileNode, &Inserted); + if (!NT_SUCCESS(Result)) + { + FspDebugLog(__FUNCTION__ ": cannot insert into FileNodeMap; aborting\n"); + abort(); + } + assert(Inserted); } - return STATUS_SUCCESS; + Result = STATUS_SUCCESS; + +exit: + for (Index = 0; Context.Count > Index; Index++) + { + DescendantFileNode = Context.FileNodes[Index]; + DescendantFileNode->RefCount--; + } + free(Context.FileNodes); + + return Result; } static NTSTATUS GetSecurity(FSP_FILE_SYSTEM *FileSystem,