// Copyright (C) Stichting Deltares 2016. All rights reserved.
//
// This file is part of Ringtoets.
//
// Ringtoets is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
//
// All names, logos, and references to "Deltares" are registered trademarks of
// Stichting Deltares and remain full property of Stichting Deltares at all times.
// All rights reserved.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Data.Entity;
using System.Linq;
using System.Linq.Expressions;

namespace Application.Ringtoets.Storage.TestUtil
{
    /// <summary>
    /// In-memory stand-in for a <see cref="DbSet{TEntity}"/> that stores its entities in an
    /// <see cref="ObservableCollection{T}"/> instead of a database.
    /// </summary>
    /// <typeparam name="T">The entity type held by the set.</typeparam>
    public sealed class TestDbSet<T> : DbSet<T>, IDbSet<T> where T : class
    {
        private readonly IQueryable<T> queryable;
        private readonly ObservableCollection<T> collection;

        /// <summary>
        /// Creates a new instance of <see cref="TestDbSet{T}"/>.
        /// </summary>
        /// <param name="queryable">The collection that backs this set. The same instance is
        /// mutated by <see cref="Add"/>, <see cref="Remove"/> and <see cref="RemoveRange"/>
        /// and is exposed through <see cref="Local"/>.</param>
        public TestDbSet(ObservableCollection<T> queryable)
        {
            collection = queryable;
            this.queryable = queryable.AsQueryable();
        }

        /// <summary>
        /// Gets the query provider of the queryable view over the backing collection.
        /// </summary>
        public IQueryProvider Provider
        {
            get
            {
                return queryable.Provider;
            }
        }

        /// <summary>
        /// Gets the expression tree of the queryable view over the backing collection.
        /// </summary>
        public Expression Expression
        {
            get
            {
                return queryable.Expression;
            }
        }

        /// <summary>
        /// Gets the element type of the queryable view over the backing collection.
        /// </summary>
        public Type ElementType
        {
            get
            {
                return queryable.ElementType;
            }
        }

        /// <summary>
        /// Gets the backing collection itself (not a copy).
        /// </summary>
        public override ObservableCollection<T> Local
        {
            get
            {
                return collection;
            }
        }

        /// <summary>
        /// Removes all given entities from the backing collection.
        /// </summary>
        /// <param name="entities">The entities to remove.</param>
        /// <returns>A snapshot list of <paramref name="entities"/>.</returns>
        public override IEnumerable<T> RemoveRange(IEnumerable<T> entities)
        {
            // Take a snapshot first: 'entities' may be a view over the backing collection,
            // which must not be modified while it is being enumerated.
            var list = entities.ToList();
            foreach (var e in list)
            {
                collection.Remove(e);
            }
            return list;
        }

        /// <summary>
        /// Adds the entity to the backing collection.
        /// </summary>
        /// <param name="entity">The entity to add.</param>
        /// <returns><paramref name="entity"/>.</returns>
        public override T Add(T entity)
        {
            collection.Add(entity);
            return entity;
        }

        /// <summary>
        /// Removes the entity from the backing collection.
        /// </summary>
        /// <param name="entity">The entity to remove.</param>
        /// <returns><paramref name="entity"/>, whether or not it was present.</returns>
        public override T Remove(T entity)
        {
            collection.Remove(entity);
            return entity;
        }

        /// <summary>
        /// Finds the entity whose conventional key property equals the first given key value.
        /// </summary>
        /// <param name="keyValues">The key values; only the first element is used.</param>
        /// <returns>The single matching entity, or <c>null</c> when there is none.</returns>
        /// <exception cref="MissingMemberException">Thrown when <typeparamref name="T"/> has no
        /// public property named after the type with an "Id" suffix.</exception>
        /// <exception cref="InvalidOperationException">Thrown when more than one entity
        /// matches the key.</exception>
        public override T Find(params object[] keyValues)
        {
            var propertyInfo = typeof(T).GetProperty(typeof(T).Name + "Id");
            if (propertyInfo == null)
            {
                // Naming convention: the primary key of an entity should be named
                // '<type name>Id', but this type violates that convention.
                throw new MissingMemberException(typeof(T).Name, typeof(T).Name + "Id");
            }

            // The static object.Equals is null-safe, so an entity whose key property is null
            // is simply not a match instead of causing a NullReferenceException.
            // Both sides are boxed, so the key must have the same runtime type as the
            // property (for example, an int key never equals a long property value).
            return collection.SingleOrDefault(i => object.Equals(propertyInfo.GetValue(i, null), keyValues[0]));
        }

        IEnumerator<T> IEnumerable<T>.GetEnumerator()
        {
            return collection.GetEnumerator();
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return collection.GetEnumerator();
        }
    }
}