|
| 1 | +using RefactorThis.GraphDiff.Internal.Graph; |
1 | 2 | using System; |
2 | 3 | using System.Collections.Generic; |
3 | 4 | using System.Data.Entity; |
4 | | -using System.Data.Entity.Core; |
5 | | -using System.Data.Entity.Core.Metadata.Edm; |
6 | | -using System.Data.Entity.Core.Objects; |
7 | 5 | using System.Data.Entity.Infrastructure; |
8 | 6 | using System.Linq; |
| 7 | +using System.Linq.Expressions; |
9 | 8 | using System.Reflection; |
10 | 9 |
|
11 | | -namespace RefactorThis.GraphDiff.Internal.Graph |
| 10 | +namespace RefactorThis.GraphDiff.Internal |
12 | 11 | { |
13 | | - internal class GraphNode |
| 12 | + /// <summary> |
| 13 | + /// GraphDiff main access point. |
| 14 | + /// </summary> |
| 15 | + /// <typeparam name="T">The root agreggate type</typeparam> |
| 16 | + internal class GraphDiffer<T> where T : class, new() |
14 | 17 | { |
15 | | - #region Fields, Properties Constructors |
| 18 | + private readonly GraphNode _root; |
16 | 19 |
|
17 | | - public GraphNode Parent { get; private set; } |
18 | | - public Stack<GraphNode> Members { get; private set; } |
19 | | - |
20 | | - protected readonly PropertyInfo Accessor; |
21 | | - |
22 | | - protected string IncludeString |
23 | | - { |
24 | | - get |
25 | | - { |
26 | | - var ownIncludeString = Accessor != null ? Accessor.Name : null; |
27 | | - return Parent != null && Parent.IncludeString != null |
28 | | - ? Parent.IncludeString + "." + ownIncludeString |
29 | | - : ownIncludeString; |
30 | | - } |
31 | | - } |
32 | | - |
33 | | - public GraphNode() |
34 | | - { |
35 | | - Members = new Stack<GraphNode>(); |
36 | | - } |
37 | | - |
38 | | - protected GraphNode(GraphNode parent, PropertyInfo accessor) |
| 20 | + public GraphDiffer(GraphNode root) |
39 | 21 | { |
40 | | - Accessor = accessor; |
41 | | - Members = new Stack<GraphNode>(); |
42 | | - Parent = parent; |
| 22 | + _root = root; |
43 | 23 | } |
44 | 24 |
|
45 | | - #endregion |
46 | | - |
47 | | - // overridden by different implementations |
48 | | - public virtual void Update<T>(DbContext context, T persisted, T updating) where T : class, new() |
| 25 | + public T Merge(DbContext context, T updating) |
49 | 26 | { |
50 | | - UpdateValuesWithConcurrencyCheck(context, updating, persisted); |
51 | | - |
52 | | - // Foreach branch perform recursive update |
53 | | - foreach (var member in Members) |
| 27 | + bool isAutoDetectEnabled = context.Configuration.AutoDetectChangesEnabled; |
| 28 | + try |
54 | 29 | { |
55 | | - member.Update(context, persisted, updating); |
56 | | - } |
57 | | - } |
| 30 | + // performance improvement for large graphs |
| 31 | + context.Configuration.AutoDetectChangesEnabled = false; |
58 | 32 |
|
59 | | - protected T GetValue<T>(object instance) |
60 | | - { |
61 | | - return (T)Accessor.GetValue(instance, null); |
62 | | - } |
63 | | - |
64 | | - protected void SetValue(object instance, object value) |
65 | | - { |
66 | | - Accessor.SetValue(instance, value, null); |
67 | | - } |
| 33 | + // Get our entity with all includes needed, or add |
| 34 | + T persisted = GetOrAddPersistedEntity(context, updating); |
68 | 35 |
|
69 | | - protected static EntityKey CreateEntityKey(IObjectContextAdapter context, object entity) |
70 | | - { |
71 | | - if (entity == null) |
72 | | - { |
73 | | - throw new ArgumentNullException("entity"); |
74 | | - } |
| 36 | + if (context.Entry(updating).State != EntityState.Detached) |
| 37 | + { |
| 38 | + throw new InvalidOperationException("GraphDiff supports detached entities only at this time. Please try AsNoTracking() or detach your entites before calling the UpdateGraph method"); |
| 39 | + } |
75 | 40 |
|
76 | | - return context.ObjectContext.CreateEntityKey(context.GetEntitySetName(entity.GetType()), entity); |
77 | | - } |
| 41 | + // Perform recursive update |
| 42 | + _root.Update(context, persisted, updating); |
78 | 43 |
|
79 | | - internal void GetIncludeStrings(DbContext context, List<string> includeStrings) |
80 | | - { |
81 | | - var ownIncludeString = IncludeString; |
82 | | - if (!string.IsNullOrEmpty(ownIncludeString)) |
83 | | - { |
84 | | - includeStrings.Add(ownIncludeString); |
| 44 | + return persisted; |
85 | 45 | } |
86 | | - |
87 | | - includeStrings.AddRange(GetRequiredNavigationPropertyIncludes(context)); |
88 | | - |
89 | | - foreach (var member in Members) |
| 46 | + finally |
90 | 47 | { |
91 | | - member.GetIncludeStrings(context, includeStrings); |
| 48 | + context.Configuration.AutoDetectChangesEnabled = isAutoDetectEnabled; |
92 | 49 | } |
93 | 50 | } |
94 | 51 |
|
95 | | - protected virtual IEnumerable<string> GetRequiredNavigationPropertyIncludes(DbContext context) |
96 | | - { |
97 | | - return new string[0]; |
98 | | - } |
99 | | - |
100 | | - protected static IEnumerable<string> GetRequiredNavigationPropertyIncludes(DbContext context, Type entityType, string ownIncludeString) |
101 | | - { |
102 | | - return context.GetRequiredNavigationPropertiesForType(entityType) |
103 | | - .Select(navigationProperty => ownIncludeString + "." + navigationProperty.Name); |
104 | | - } |
105 | | - |
106 | | - protected static void AttachCyclicNavigationProperty(IObjectContextAdapter context, object parent, object child) |
| 52 | + private T GetOrAddPersistedEntity(DbContext context, T entity) |
107 | 53 | { |
108 | | - if (parent == null || child == null) return; |
109 | | - |
110 | | - var parentType = ObjectContext.GetObjectType(parent.GetType()); |
111 | | - var childType = ObjectContext.GetObjectType(child.GetType()); |
112 | | - |
113 | | - var navigationProperties = context.GetNavigationPropertiesForType(childType); |
114 | | - |
115 | | - var parentNavigationProperty = navigationProperties |
116 | | - .Where(navigation => navigation.TypeUsage.EdmType.Name == parentType.Name) |
117 | | - .Select(navigation => childType.GetProperty(navigation.Name, BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public)) |
118 | | - .FirstOrDefault(); |
119 | | - |
120 | | - if (parentNavigationProperty != null) |
121 | | - parentNavigationProperty.SetValue(child, parent, null); |
122 | | - } |
123 | | - |
124 | | - protected static void UpdateValuesWithConcurrencyCheck<T>(DbContext context, T from, T to) where T : class |
125 | | - { |
126 | | - if (context.Entry(to).State != EntityState.Added) |
| 54 | + if (entity == null) |
127 | 55 | { |
128 | | - EnsureConcurrency(context, from, to); |
| 56 | + throw new ArgumentNullException("entity"); |
129 | 57 | } |
130 | 58 |
|
131 | | - context.Entry(to).CurrentValues.SetValues(from); |
132 | | - } |
133 | | - |
134 | | - protected static object AttachAndReloadAssociatedEntity(DbContext context, object entity) |
135 | | - { |
136 | | - var localCopy = FindLocalByKey(context, entity); |
137 | | - if (localCopy != null) return localCopy; |
| 59 | + var persisted = FindEntityMatching(context, entity); |
138 | 60 |
|
139 | | - if (context.Entry(entity).State == EntityState.Detached) |
| 61 | + if (persisted == null) |
140 | 62 | { |
141 | | - var entityType = ObjectContext.GetObjectType(entity.GetType()); |
142 | | - var instance = CreateEmptyEntityWithKey(context, entity); |
143 | | - |
144 | | - context.Set(entityType).Attach(instance); |
145 | | - context.Entry(instance).Reload(); |
146 | | - |
147 | | - AttachRequiredNavigationProperties(context, entity, instance); |
148 | | - return instance; |
| 63 | + // we are always working with 2 graphs, simply add a 'persisted' one if none exists, |
| 64 | + // this ensures that only the changes we make within the bounds of the mapping are attempted. |
| 65 | + persisted = new T(); |
| 66 | + context.Set<T>().Add(persisted); |
149 | 67 | } |
150 | 68 |
|
151 | | - if (GraphDiffConfiguration.ReloadAssociatedEntitiesWhenAttached) |
152 | | - { |
153 | | - context.Entry(entity).Reload(); |
154 | | - } |
155 | | - |
156 | | - return entity; |
157 | | - } |
158 | | - |
159 | | - private static object FindLocalByKey(DbContext context, object entity) |
160 | | - { |
161 | | - var eType = ObjectContext.GetObjectType(entity.GetType()); |
162 | | - return context.Set(eType).Local.OfType<object>().FirstOrDefault(local => IsKeyIdentical(context, local, entity)); |
| 69 | + return persisted; |
163 | 70 | } |
164 | 71 |
|
165 | | - protected static void AttachRequiredNavigationProperties(DbContext context, object updating, object persisted) |
| 72 | + private T FindEntityMatching(DbContext context, T entity) |
166 | 73 | { |
167 | | - var entityType = ObjectContext.GetObjectType(updating.GetType()); |
168 | | - foreach (var navigationProperty in context.GetRequiredNavigationPropertiesForType(updating.GetType())) |
169 | | - { |
170 | | - var navigationPropertyInfo = entityType.GetProperty(navigationProperty.Name, BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public); |
| 74 | + var includeStrings = new List<string>(); |
| 75 | + _root.GetIncludeStrings(context, includeStrings); |
171 | 76 |
|
172 | | - var associatedEntity = navigationPropertyInfo.GetValue(updating, null); |
173 | | - if (associatedEntity != null) |
174 | | - { |
175 | | - associatedEntity = FindEntityByKey(context, associatedEntity); |
176 | | - } |
177 | | - |
178 | | - navigationPropertyInfo.SetValue(persisted, associatedEntity, null); |
179 | | - } |
180 | | - } |
181 | | - |
182 | | - private static object FindEntityByKey(DbContext context, object associatedEntity) |
183 | | - { |
184 | | - var associatedEntityType = ObjectContext.GetObjectType(associatedEntity.GetType()); |
185 | | - var keyFields = context.GetPrimaryKeyFieldsFor(associatedEntityType); |
186 | | - var keys = keyFields.Select(key => key.GetValue(associatedEntity, null)).ToArray(); |
187 | | - return context.Set(associatedEntityType).Find(keys); |
188 | | - } |
| 77 | + // attach includes to IQueryable |
| 78 | + var query = context.Set<T>().AsQueryable(); |
| 79 | + query = includeStrings.Aggregate(query, (current, include) => current.Include(include)); |
189 | 80 |
|
190 | | - protected static object CreateEmptyEntityWithKey(IObjectContextAdapter context, object entity) |
191 | | - { |
192 | | - var instance = Activator.CreateInstance(entity.GetType()); |
193 | | - CopyPrimaryKeyFields(context, entity, instance); |
194 | | - return instance; |
| 81 | + // Run the find operation |
| 82 | + return query.SingleOrDefault(CreateKeyPredicateExpression(context, entity)); |
195 | 83 | } |
196 | 84 |
|
197 | | - private static void CopyPrimaryKeyFields(IObjectContextAdapter context, object from, object to) |
| 85 | + private static Expression<Func<T, bool>> CreateKeyPredicateExpression(IObjectContextAdapter context, T entity) |
198 | 86 | { |
199 | | - var keyProperties = context.GetPrimaryKeyFieldsFor(from.GetType()).ToList(); |
200 | | - foreach (var keyProperty in keyProperties) |
201 | | - keyProperty.SetValue(to, keyProperty.GetValue(from, null), null); |
202 | | - } |
| 87 | + // get key properties of T |
| 88 | + var keyProperties = context.GetPrimaryKeyFieldsFor(typeof(T)).ToList(); |
203 | 89 |
|
204 | | - protected static bool IsKeyIdentical(DbContext context, object newValue, object dbValue) |
205 | | - { |
206 | | - if (newValue == null || dbValue == null) return false; |
| 90 | + ParameterExpression parameter = Expression.Parameter(typeof(T)); |
| 91 | + Expression expression = CreateEqualsExpression(entity, keyProperties[0], parameter); |
| 92 | + for (int i = 1; i < keyProperties.Count; i++) |
| 93 | + expression = Expression.And(expression, CreateEqualsExpression(entity, keyProperties[i], parameter)); |
207 | 94 |
|
208 | | - return CreateEntityKey(context, newValue) == CreateEntityKey(context, dbValue); |
| 95 | + return Expression.Lambda<Func<T, bool>>(expression, parameter); |
209 | 96 | } |
210 | 97 |
|
211 | | - private static void EnsureConcurrency<T>(IObjectContextAdapter db, T entity1, T entity2) |
| 98 | + private static Expression CreateEqualsExpression(object entity, PropertyInfo keyProperty, Expression parameter) |
212 | 99 | { |
213 | | - // get concurrency properties of T |
214 | | - var entityType = ObjectContext.GetObjectType(entity1.GetType()); |
215 | | - var metadata = db.ObjectContext.MetadataWorkspace; |
216 | | - |
217 | | - var objType = metadata.GetEntityTypeByType(entityType); |
218 | | - |
219 | | - // need internal string, code smells bad.. any better way to do this? |
220 | | - var cTypeName = (string)objType.GetType() |
221 | | - .GetProperty("CSpaceTypeName", BindingFlags.Instance | BindingFlags.NonPublic) |
222 | | - .GetValue(objType, null); |
223 | | - |
224 | | - var conceptualType = metadata.GetItems<EntityType>(DataSpace.CSpace).Single(p => p.FullName == cTypeName); |
225 | | - var concurrencyProperties = conceptualType.Members |
226 | | - .Where(member => member.TypeUsage.Facets.Any(facet => facet.Name == "ConcurrencyMode" && (ConcurrencyMode)facet.Value == ConcurrencyMode.Fixed)) |
227 | | - .Select(member => entityType.GetProperty(member.Name, BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public)) |
228 | | - .ToList(); |
229 | | - |
230 | | - // Check if concurrency properties are equal |
231 | | - // TODO EF should do this automatically should it not? |
232 | | - foreach (var concurrencyProp in concurrencyProperties) |
233 | | - { |
234 | | - // if is byte[] use array comparison, else equals(). |
235 | | - |
236 | | - var type = concurrencyProp.PropertyType; |
237 | | - var obj1 = concurrencyProp.GetValue(entity1, null); |
238 | | - var obj2 = concurrencyProp.GetValue(entity2, null); |
239 | | - |
240 | | - if ( |
241 | | - (obj1 == null || obj2 == null) || |
242 | | - (type == typeof (byte[]) && !((byte[]) obj1).SequenceEqual((byte[]) obj2)) || |
243 | | - (type != typeof (byte[]) && !obj1.Equals(obj2)) |
244 | | - ) |
245 | | - { |
246 | | - throw new DbUpdateConcurrencyException(String.Format("{0} failed optimistic concurrency", concurrencyProp.Name)); |
247 | | - } |
248 | | - } |
| 100 | + return Expression.Equal(Expression.Property(parameter, keyProperty), Expression.Constant(keyProperty.GetValue(entity, null), keyProperty.PropertyType)); |
249 | 101 | } |
250 | 102 | } |
251 | 103 | } |
0 commit comments