Skip to content

Commit bc73032

Browse files
authored
add extended methode
add extended methode
1 parent 30d6a74 commit bc73032

File tree

1 file changed

+204
-56
lines changed

1 file changed

+204
-56
lines changed
Lines changed: 204 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,103 +1,251 @@
1-
using RefactorThis.GraphDiff.Internal.Graph;
21
using System;
32
using System.Collections.Generic;
43
using System.Data.Entity;
4+
using System.Data.Entity.Core;
5+
using System.Data.Entity.Core.Metadata.Edm;
6+
using System.Data.Entity.Core.Objects;
57
using System.Data.Entity.Infrastructure;
68
using System.Linq;
7-
using System.Linq.Expressions;
89
using System.Reflection;
910

10-
namespace RefactorThis.GraphDiff.Internal
11+
namespace RefactorThis.GraphDiff.Internal.Graph
1112
{
12-
/// <summary>
13-
/// GraphDiff main access point.
14-
/// </summary>
15-
/// <typeparam name="T">The root agreggate type</typeparam>
16-
internal class GraphDiffer<T> where T : class, new()
13+
internal class GraphNode
1714
{
18-
private readonly GraphNode _root;
15+
#region Fields, Properties Constructors
1916

20-
public GraphDiffer(GraphNode root)
17+
public GraphNode Parent { get; private set; }
18+
public Stack<GraphNode> Members { get; private set; }
19+
20+
protected readonly PropertyInfo Accessor;
21+
22+
protected string IncludeString
2123
{
22-
_root = root;
24+
get
25+
{
26+
var ownIncludeString = Accessor != null ? Accessor.Name : null;
27+
return Parent != null && Parent.IncludeString != null
28+
? Parent.IncludeString + "." + ownIncludeString
29+
: ownIncludeString;
30+
}
2331
}
2432

25-
public T Merge(DbContext context, T updating)
33+
public GraphNode()
2634
{
27-
bool isAutoDetectEnabled = context.Configuration.AutoDetectChangesEnabled;
28-
try
29-
{
30-
// performance improvement for large graphs
31-
context.Configuration.AutoDetectChangesEnabled = false;
35+
Members = new Stack<GraphNode>();
36+
}
3237

33-
// Get our entity with all includes needed, or add
34-
T persisted = GetOrAddPersistedEntity(context, updating);
38+
protected GraphNode(GraphNode parent, PropertyInfo accessor)
39+
{
40+
Accessor = accessor;
41+
Members = new Stack<GraphNode>();
42+
Parent = parent;
43+
}
3544

36-
if (context.Entry(updating).State != EntityState.Detached)
37-
{
38-
throw new InvalidOperationException("GraphDiff supports detached entities only at this time. Please try AsNoTracking() or detach your entites before calling the UpdateGraph method");
39-
}
45+
#endregion
4046

41-
// Perform recursive update
42-
_root.Update(context, persisted, updating);
47+
// overridden by different implementations
48+
public virtual void Update<T>(DbContext context, T persisted, T updating) where T : class, new()
49+
{
50+
UpdateValuesWithConcurrencyCheck(context, updating, persisted);
4351

44-
return persisted;
45-
}
46-
finally
52+
// Foreach branch perform recursive update
53+
foreach (var member in Members)
4754
{
48-
context.Configuration.AutoDetectChangesEnabled = isAutoDetectEnabled;
55+
member.Update(context, persisted, updating);
4956
}
5057
}
5158

52-
private T GetOrAddPersistedEntity(DbContext context, T entity)
59+
protected T GetValue<T>(object instance)
60+
{
61+
return (T)Accessor.GetValue(instance, null);
62+
}
63+
64+
protected void SetValue(object instance, object value)
65+
{
66+
Accessor.SetValue(instance, value, null);
67+
}
68+
69+
protected static EntityKey CreateEntityKey(IObjectContextAdapter context, object entity)
5370
{
5471
if (entity == null)
5572
{
5673
throw new ArgumentNullException("entity");
5774
}
5875

59-
var persisted = FindEntityMatching(context, entity);
76+
return context.ObjectContext.CreateEntityKey(context.GetEntitySetName(entity.GetType()), entity);
77+
}
78+
79+
internal void GetIncludeStrings(DbContext context, List<string> includeStrings)
80+
{
81+
var ownIncludeString = IncludeString;
82+
if (!string.IsNullOrEmpty(ownIncludeString))
83+
{
84+
includeStrings.Add(ownIncludeString);
85+
}
86+
87+
includeStrings.AddRange(GetRequiredNavigationPropertyIncludes(context));
88+
89+
foreach (var member in Members)
90+
{
91+
member.GetIncludeStrings(context, includeStrings);
92+
}
93+
}
94+
95+
protected virtual IEnumerable<string> GetRequiredNavigationPropertyIncludes(DbContext context)
96+
{
97+
return new string[0];
98+
}
99+
100+
protected static IEnumerable<string> GetRequiredNavigationPropertyIncludes(DbContext context, Type entityType, string ownIncludeString)
101+
{
102+
return context.GetRequiredNavigationPropertiesForType(entityType)
103+
.Select(navigationProperty => ownIncludeString + "." + navigationProperty.Name);
104+
}
105+
106+
protected static void AttachCyclicNavigationProperty(IObjectContextAdapter context, object parent, object child)
107+
{
108+
if (parent == null || child == null) return;
109+
110+
var parentType = ObjectContext.GetObjectType(parent.GetType());
111+
var childType = ObjectContext.GetObjectType(child.GetType());
112+
113+
var navigationProperties = context.GetNavigationPropertiesForType(childType);
114+
115+
var parentNavigationProperty = navigationProperties
116+
.Where(navigation => navigation.TypeUsage.EdmType.Name == parentType.Name)
117+
.Select(navigation => childType.GetProperty(navigation.Name, BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public))
118+
.FirstOrDefault();
119+
120+
if (parentNavigationProperty != null)
121+
parentNavigationProperty.SetValue(child, parent, null);
122+
}
123+
124+
protected static void UpdateValuesWithConcurrencyCheck<T>(DbContext context, T from, T to) where T : class
125+
{
126+
if (context.Entry(to).State != EntityState.Added)
127+
{
128+
EnsureConcurrency(context, from, to);
129+
}
130+
131+
context.Entry(to).CurrentValues.SetValues(from);
132+
}
133+
134+
protected static object AttachAndReloadAssociatedEntity(DbContext context, object entity)
135+
{
136+
var localCopy = FindLocalByKey(context, entity);
137+
if (localCopy != null) return localCopy;
60138

61-
if (persisted == null)
139+
if (context.Entry(entity).State == EntityState.Detached)
62140
{
63-
// we are always working with 2 graphs, simply add a 'persisted' one if none exists,
64-
// this ensures that only the changes we make within the bounds of the mapping are attempted.
65-
persisted = new T();
66-
context.Set<T>().Add(persisted);
141+
var entityType = ObjectContext.GetObjectType(entity.GetType());
142+
var instance = CreateEmptyEntityWithKey(context, entity);
143+
144+
context.Set(entityType).Attach(instance);
145+
context.Entry(instance).Reload();
146+
147+
AttachRequiredNavigationProperties(context, entity, instance);
148+
return instance;
67149
}
68150

69-
return persisted;
151+
if (GraphDiffConfiguration.ReloadAssociatedEntitiesWhenAttached)
152+
{
153+
context.Entry(entity).Reload();
154+
}
155+
156+
return entity;
157+
}
158+
159+
private static object FindLocalByKey(DbContext context, object entity)
160+
{
161+
var eType = ObjectContext.GetObjectType(entity.GetType());
162+
return context.Set(eType).Local.OfType<object>().FirstOrDefault(local => IsKeyIdentical(context, local, entity));
70163
}
71164

72-
private T FindEntityMatching(DbContext context, T entity)
165+
protected static void AttachRequiredNavigationProperties(DbContext context, object updating, object persisted)
73166
{
74-
var includeStrings = new List<string>();
75-
_root.GetIncludeStrings(context, includeStrings);
167+
var entityType = ObjectContext.GetObjectType(updating.GetType());
168+
foreach (var navigationProperty in context.GetRequiredNavigationPropertiesForType(updating.GetType()))
169+
{
170+
var navigationPropertyInfo = entityType.GetProperty(navigationProperty.Name, BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public);
76171

77-
// attach includes to IQueryable
78-
var query = context.Set<T>().AsQueryable();
79-
query = includeStrings.Aggregate(query, (current, include) => current.Include(include));
172+
var associatedEntity = navigationPropertyInfo.GetValue(updating, null);
173+
if (associatedEntity != null)
174+
{
175+
associatedEntity = FindEntityByKey(context, associatedEntity);
176+
}
177+
178+
navigationPropertyInfo.SetValue(persisted, associatedEntity, null);
179+
}
180+
}
181+
182+
private static object FindEntityByKey(DbContext context, object associatedEntity)
183+
{
184+
var associatedEntityType = ObjectContext.GetObjectType(associatedEntity.GetType());
185+
var keyFields = context.GetPrimaryKeyFieldsFor(associatedEntityType);
186+
var keys = keyFields.Select(key => key.GetValue(associatedEntity, null)).ToArray();
187+
return context.Set(associatedEntityType).Find(keys);
188+
}
80189

81-
// Run the find operation
82-
return query.SingleOrDefault(CreateKeyPredicateExpression(context, entity));
190+
protected static object CreateEmptyEntityWithKey(IObjectContextAdapter context, object entity)
191+
{
192+
var instance = Activator.CreateInstance(entity.GetType());
193+
CopyPrimaryKeyFields(context, entity, instance);
194+
return instance;
83195
}
84196

85-
private static Expression<Func<T, bool>> CreateKeyPredicateExpression(IObjectContextAdapter context, T entity)
197+
private static void CopyPrimaryKeyFields(IObjectContextAdapter context, object from, object to)
86198
{
87-
// get key properties of T
88-
var keyProperties = context.GetPrimaryKeyFieldsFor(typeof(T)).ToList();
199+
var keyProperties = context.GetPrimaryKeyFieldsFor(from.GetType()).ToList();
200+
foreach (var keyProperty in keyProperties)
201+
keyProperty.SetValue(to, keyProperty.GetValue(from, null), null);
202+
}
89203

90-
ParameterExpression parameter = Expression.Parameter(typeof(T));
91-
Expression expression = CreateEqualsExpression(entity, keyProperties[0], parameter);
92-
for (int i = 1; i < keyProperties.Count; i++)
93-
expression = Expression.And(expression, CreateEqualsExpression(entity, keyProperties[i], parameter));
204+
protected static bool IsKeyIdentical(DbContext context, object newValue, object dbValue)
205+
{
206+
if (newValue == null || dbValue == null) return false;
94207

95-
return Expression.Lambda<Func<T, bool>>(expression, parameter);
208+
return CreateEntityKey(context, newValue) == CreateEntityKey(context, dbValue);
96209
}
97210

98-
private static Expression CreateEqualsExpression(object entity, PropertyInfo keyProperty, Expression parameter)
211+
private static void EnsureConcurrency<T>(IObjectContextAdapter db, T entity1, T entity2)
99212
{
100-
return Expression.Equal(Expression.Property(parameter, keyProperty), Expression.Constant(keyProperty.GetValue(entity, null), keyProperty.PropertyType));
213+
// get concurrency properties of T
214+
var entityType = ObjectContext.GetObjectType(entity1.GetType());
215+
var metadata = db.ObjectContext.MetadataWorkspace;
216+
217+
var objType = metadata.GetEntityTypeByType(entityType);
218+
219+
// need internal string, code smells bad.. any better way to do this?
220+
var cTypeName = (string)objType.GetType()
221+
.GetProperty("CSpaceTypeName", BindingFlags.Instance | BindingFlags.NonPublic)
222+
.GetValue(objType, null);
223+
224+
var conceptualType = metadata.GetItems<EntityType>(DataSpace.CSpace).Single(p => p.FullName == cTypeName);
225+
var concurrencyProperties = conceptualType.Members
226+
.Where(member => member.TypeUsage.Facets.Any(facet => facet.Name == "ConcurrencyMode" && (ConcurrencyMode)facet.Value == ConcurrencyMode.Fixed))
227+
.Select(member => entityType.GetProperty(member.Name, BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public))
228+
.ToList();
229+
230+
// Check if concurrency properties are equal
231+
// TODO EF should do this automatically should it not?
232+
foreach (var concurrencyProp in concurrencyProperties)
233+
{
234+
// if is byte[] use array comparison, else equals().
235+
236+
var type = concurrencyProp.PropertyType;
237+
var obj1 = concurrencyProp.GetValue(entity1, null);
238+
var obj2 = concurrencyProp.GetValue(entity2, null);
239+
240+
if (
241+
(obj1 == null || obj2 == null) ||
242+
(type == typeof (byte[]) && !((byte[]) obj1).SequenceEqual((byte[]) obj2)) ||
243+
(type != typeof (byte[]) && !obj1.Equals(obj2))
244+
)
245+
{
246+
throw new DbUpdateConcurrencyException(String.Format("{0} failed optimistic concurrency", concurrencyProp.Name));
247+
}
248+
}
101249
}
102250
}
103-
}
251+
}

0 commit comments

Comments
 (0)