Add the ability to import into ArchiveModelManagers from a stream

This commit is contained in:
Dean Herbert 2020-12-07 18:00:45 +09:00
parent a5e2509d52
commit eb38bc4b4c
10 changed files with 132 additions and 43 deletions

View File

@ -14,6 +14,7 @@ using osu.Framework.Allocation;
using osu.Framework.Extensions;
using osu.Framework.Logging;
using osu.Game.Beatmaps;
using osu.Game.Database;
using osu.Game.IO;
using osu.Game.Rulesets.Osu;
using osu.Game.Rulesets.Osu.Objects;
@ -127,7 +128,7 @@ namespace osu.Game.Tests.Beatmaps.IO
// zip files differ because different compression or encoder.
Assert.AreNotEqual(hashBefore, hashFile(temp));
var importedSecondTime = await osu.Dependencies.Get<BeatmapManager>().Import(temp);
var importedSecondTime = await osu.Dependencies.Get<BeatmapManager>().Import(new ImportTask(temp));
ensureLoaded(osu);
@ -184,7 +185,7 @@ namespace osu.Game.Tests.Beatmaps.IO
zip.SaveTo(temp, new ZipWriterOptions(CompressionType.Deflate));
}
var importedSecondTime = await osu.Dependencies.Get<BeatmapManager>().Import(temp);
var importedSecondTime = await osu.Dependencies.Get<BeatmapManager>().Import(new ImportTask(temp));
ensureLoaded(osu);
@ -235,7 +236,7 @@ namespace osu.Game.Tests.Beatmaps.IO
zip.SaveTo(temp, new ZipWriterOptions(CompressionType.Deflate));
}
var importedSecondTime = await osu.Dependencies.Get<BeatmapManager>().Import(temp);
var importedSecondTime = await osu.Dependencies.Get<BeatmapManager>().Import(new ImportTask(temp));
ensureLoaded(osu);
@ -351,7 +352,7 @@ namespace osu.Game.Tests.Beatmaps.IO
// this will trigger purging of the existing beatmap (online set id match) but should rollback due to broken osu.
try
{
await manager.Import(breakTemp);
await manager.Import(new ImportTask(breakTemp));
}
catch
{
@ -614,7 +615,7 @@ namespace osu.Game.Tests.Beatmaps.IO
zip.SaveTo(temp, new ZipWriterOptions(CompressionType.Deflate));
}
var imported = await osu.Dependencies.Get<BeatmapManager>().Import(temp);
var imported = await osu.Dependencies.Get<BeatmapManager>().Import(new ImportTask(temp));
ensureLoaded(osu);
@ -667,7 +668,7 @@ namespace osu.Game.Tests.Beatmaps.IO
zip.SaveTo(temp, new ZipWriterOptions(CompressionType.Deflate));
}
var imported = await osu.Dependencies.Get<BeatmapManager>().Import(temp);
var imported = await osu.Dependencies.Get<BeatmapManager>().Import(new ImportTask(temp));
ensureLoaded(osu);
@ -821,7 +822,7 @@ namespace osu.Game.Tests.Beatmaps.IO
var manager = osu.Dependencies.Get<BeatmapManager>();
var importedSet = await manager.Import(temp);
var importedSet = await manager.Import(new ImportTask(temp));
ensureLoaded(osu);

View File

@ -6,6 +6,7 @@ using System.Linq;
using NUnit.Framework;
using osu.Framework.Testing;
using osu.Game.Beatmaps;
using osu.Game.Database;
using osu.Game.Input.Bindings;
using osu.Game.Overlays;
using osu.Game.Tests.Resources;
@ -52,7 +53,7 @@ namespace osu.Game.Tests.Visual.Menus
AddStep("import beatmap with track", () =>
{
var setWithTrack = Game.BeatmapManager.Import(TestResources.GetTestBeatmapForImport()).Result;
var setWithTrack = Game.BeatmapManager.Import(new ImportTask(TestResources.GetTestBeatmapForImport())).Result;
Beatmap.Value = Game.BeatmapManager.GetWorkingBeatmap(setWithTrack.Beatmaps.First());
});

View File

@ -42,7 +42,7 @@ namespace osu.Game.Tests.Visual.Online
ensureSoleilyRemoved();
createButtonWithBeatmap(createSoleily());
AddAssert("button state not downloaded", () => downloadButton.DownloadState == DownloadState.NotDownloaded);
AddStep("import soleily", () => beatmaps.Import(new[] { TestResources.GetTestBeatmapForImport() }));
AddStep("import soleily", () => beatmaps.Import(TestResources.GetTestBeatmapForImport()));
AddUntilStep("wait for beatmap import", () => beatmaps.GetAllUsableBeatmapSets().Any(b => b.OnlineBeatmapSetID == 241526));
createButtonWithBeatmap(createSoleily());
AddAssert("button state downloaded", () => downloadButton.DownloadState == DownloadState.LocallyAvailable);

View File

@ -12,6 +12,7 @@ using osu.Framework.Platform;
using osu.Framework.Testing;
using osu.Framework.Utils;
using osu.Game.Beatmaps;
using osu.Game.Database;
using osu.Game.Graphics.Cursor;
using osu.Game.Graphics.UserInterface;
using osu.Game.Online.Leaderboards;
@ -83,7 +84,7 @@ namespace osu.Game.Tests.Visual.UserInterface
dependencies.Cache(beatmapManager = new BeatmapManager(LocalStorage, ContextFactory, rulesetStore, null, dependencies.Get<AudioManager>(), dependencies.Get<GameHost>(), Beatmap.Default));
dependencies.Cache(scoreManager = new ScoreManager(rulesetStore, () => beatmapManager, LocalStorage, null, ContextFactory));
beatmap = beatmapManager.Import(TestResources.GetTestBeatmapForImport()).Result.Beatmaps[0];
beatmap = beatmapManager.Import(new ImportTask(TestResources.GetTestBeatmapForImport())).Result.Beatmaps[0];
for (int i = 0; i < 50; i++)
{

View File

@ -21,9 +21,7 @@ using osu.Game.IO;
using osu.Game.IO.Archives;
using osu.Game.IPC;
using osu.Game.Overlays.Notifications;
using osu.Game.Utils;
using SharpCompress.Archives.Zip;
using SharpCompress.Common;
using FileInfo = osu.Game.IO.FileInfo;
namespace osu.Game.Database
@ -114,10 +112,19 @@ namespace osu.Game.Database
PostNotification?.Invoke(notification);
return Import(notification, paths);
return Import(notification, paths.Select(p => new ImportTask(p)).ToArray());
}
protected async Task<IEnumerable<TModel>> Import(ProgressNotification notification, params string[] paths)
public Task Import(Stream stream, string filename)
{
var notification = new ProgressNotification { State = ProgressNotificationState.Active };
PostNotification?.Invoke(notification);
return Import(notification, new ImportTask(stream, filename));
}
protected async Task<IEnumerable<TModel>> Import(ProgressNotification notification, params ImportTask[] tasks)
{
notification.Progress = 0;
notification.Text = $"{HumanisedModelName.Humanize(LetterCasing.Title)} import is initialising...";
@ -126,13 +133,13 @@ namespace osu.Game.Database
var imported = new List<TModel>();
await Task.WhenAll(paths.Select(async path =>
await Task.WhenAll(tasks.Select(async task =>
{
notification.CancellationToken.ThrowIfCancellationRequested();
try
{
var model = await Import(path, notification.CancellationToken);
var model = await Import(task, notification.CancellationToken);
lock (imported)
{
@ -140,8 +147,8 @@ namespace osu.Game.Database
imported.Add(model);
current++;
notification.Text = $"Imported {current} of {paths.Length} {HumanisedModelName}s";
notification.Progress = (float)current / paths.Length;
notification.Text = $"Imported {current} of {tasks.Length} {HumanisedModelName}s";
notification.Progress = (float)current / tasks.Length;
}
}
catch (TaskCanceledException)
@ -150,7 +157,7 @@ namespace osu.Game.Database
}
catch (Exception e)
{
Logger.Error(e, $@"Could not import ({Path.GetFileName(path)})", LoggingTarget.Database);
Logger.Error(e, $@"Could not import ({task})", LoggingTarget.Database);
}
}));
@ -183,16 +190,17 @@ namespace osu.Game.Database
/// <summary>
/// Import one <typeparamref name="TModel"/> from the filesystem and delete the file on success.
/// Note that this bypasses the UI flow and should only be used for special cases or testing.
/// </summary>
/// <param name="path">The archive location on disk.</param>
/// <param name="task">The archive location on disk.</param>
/// <param name="cancellationToken">An optional cancellation token.</param>
/// <returns>The imported model, if successful.</returns>
public async Task<TModel> Import(string path, CancellationToken cancellationToken = default)
internal async Task<TModel> Import(ImportTask task, CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
TModel import;
using (ArchiveReader reader = getReaderFrom(path))
using (ArchiveReader reader = task.GetReader())
import = await Import(reader, cancellationToken);
// We may or may not want to delete the file depending on where it is stored.
@ -201,12 +209,12 @@ namespace osu.Game.Database
// TODO: Add a check to prevent files from storage to be deleted.
try
{
if (import != null && File.Exists(path) && ShouldDeleteArchive(path))
File.Delete(path);
if (import != null && File.Exists(task.Path) && ShouldDeleteArchive(task.Path))
File.Delete(task.Path);
}
catch (Exception e)
{
LogForModel(import, $@"Could not delete original file after import ({Path.GetFileName(path)})", e);
LogForModel(import, $@"Could not delete original file after import ({task})", e);
}
return import;
@ -727,23 +735,6 @@ namespace osu.Game.Database
protected virtual string HumanisedModelName => $"{typeof(TModel).Name.Replace("Info", "").ToLower()}";
/// <summary>
/// Creates an <see cref="ArchiveReader"/> from a valid storage path.
/// </summary>
/// <param name="path">A file or folder path resolving the archive content.</param>
/// <returns>A reader giving access to the archive's content.</returns>
private ArchiveReader getReaderFrom(string path)
{
if (ZipUtils.IsZipArchive(path))
return new ZipArchiveReader(File.Open(path, FileMode.Open, FileAccess.Read, FileShare.Read), Path.GetFileName(path));
if (Directory.Exists(path))
return new LegacyDirectoryArchiveReader(path);
if (File.Exists(path))
return new LegacyFileArchiveReader(path);
throw new InvalidFormatException($"{path} is not a valid archive");
}
#region Event handling / delaying
private readonly List<Action> queuedEvents = new List<Action>();

View File

@ -82,7 +82,7 @@ namespace osu.Game.Database
Task.Factory.StartNew(async () =>
{
// This gets scheduled back to the update thread, but we want the import to run in the background.
var imported = await Import(notification, filename);
var imported = await Import(notification, new ImportTask(filename));
// for now a failed import will be marked as a failed download for simplicity.
if (!imported.Any())

View File

@ -2,6 +2,7 @@
// See the LICENCE file in the repository root for full licence text.
using System.Collections.Generic;
using System.IO;
using System.Threading.Tasks;
namespace osu.Game.Database
@ -17,6 +18,13 @@ namespace osu.Game.Database
/// <param name="paths">The files which should be imported.</param>
Task Import(params string[] paths);
/// <summary>
/// Import the provided stream as a simple item.
/// </summary>
/// <param name="stream">The stream to import files from. Should be in a supported archive format.</param>
/// <param name="filename">The filename of the archive being imported.</param>
Task Import(Stream stream, string filename);
/// <summary>
/// An array of accepted file extensions (in the standard format of ".abc").
/// </summary>

View File

@ -0,0 +1,73 @@
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.
using System.IO;
using osu.Game.IO.Archives;
using osu.Game.Utils;
using SharpCompress.Common;
namespace osu.Game.Database
{
/// <summary>
/// An encapsulated import task to be imported to an <see cref="ArchiveModelManager{TModel,TFileModel}"/>.
/// </summary>
public class ImportTask
{
/// <summary>
/// The path to the file (or filename in the case a stream is provided).
/// </summary>
public string Path { get; }
/// <summary>
/// An optional stream which provides the file content.
/// </summary>
public Stream Stream { get; }
/// <summary>
/// Construct a new import task from a path (on a local filesystem).
/// </summary>
public ImportTask(string path)
{
Path = path;
}
/// <summary>
/// Construct a new import task from a stream.
/// </summary>
public ImportTask(Stream stream, string filename)
{
Path = filename;
Stream = stream;
}
/// <summary>
/// Retrieve an archive reader from this task.
/// </summary>
public ArchiveReader GetReader()
{
if (Stream != null)
return new ZipArchiveReader(Stream, Path);
return getReaderFrom(Path);
}
/// <summary>
/// Creates an <see cref="ArchiveReader"/> from a valid storage path.
/// </summary>
/// <param name="path">A file or folder path resolving the archive content.</param>
/// <returns>A reader giving access to the archive's content.</returns>
private ArchiveReader getReaderFrom(string path)
{
if (ZipUtils.IsZipArchive(path))
return new ZipArchiveReader(File.Open(path, FileMode.Open, FileAccess.Read, FileShare.Read), System.IO.Path.GetFileName(path));
if (Directory.Exists(path))
return new LegacyDirectoryArchiveReader(path);
if (File.Exists(path))
return new LegacyFileArchiveReader(path);
throw new InvalidFormatException($"{path} is not a valid archive");
}
public override string ToString() => System.IO.Path.GetFileName(Path);
}
}

View File

@ -395,6 +395,17 @@ namespace osu.Game
}
}
public async Task Import(Stream stream, string filename)
{
var extension = Path.GetExtension(filename)?.ToLowerInvariant();
foreach (var importer in fileImporters)
{
if (importer.HandledExtensions.Contains(extension))
await importer.Import(stream, Path.GetFileNameWithoutExtension(filename));
}
}
public IEnumerable<string> HandledExtensions => fileImporters.SelectMany(i => i.HandledExtensions);
protected override void Dispose(bool isDisposing)

View File

@ -1,6 +1,7 @@
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
@ -99,6 +100,8 @@ namespace osu.Game.Screens.Edit.Setup
return Task.CompletedTask;
}
Task ICanAcceptFiles.Import(Stream stream, string filename) => throw new NotImplementedException();
protected override void LoadComplete()
{
base.LoadComplete();