Skip to content

Commit 04c62ff

Browse files
authored
good upload
good upload
1 parent bc73032 commit 04c62ff

File tree

1 file changed

+55
-203
lines changed

1 file changed

+55
-203
lines changed
Lines changed: 55 additions & 203 deletions
Original file line numberDiff line numberDiff line change
@@ -1,251 +1,103 @@
1+
using RefactorThis.GraphDiff.Internal.Graph;
12
using System;
23
using System.Collections.Generic;
34
using System.Data.Entity;
4-
using System.Data.Entity.Core;
5-
using System.Data.Entity.Core.Metadata.Edm;
6-
using System.Data.Entity.Core.Objects;
75
using System.Data.Entity.Infrastructure;
86
using System.Linq;
7+
using System.Linq.Expressions;
98
using System.Reflection;
109

11-
namespace RefactorThis.GraphDiff.Internal.Graph
10+
namespace RefactorThis.GraphDiff.Internal
1211
{
13-
internal class GraphNode
12+
/// <summary>
13+
/// GraphDiff main access point.
14+
/// </summary>
15+
/// <typeparam name="T">The root agreggate type</typeparam>
16+
internal class GraphDiffer<T> where T : class, new()
1417
{
15-
#region Fields, Properties Constructors
18+
private readonly GraphNode _root;
1619

17-
public GraphNode Parent { get; private set; }
18-
public Stack<GraphNode> Members { get; private set; }
19-
20-
protected readonly PropertyInfo Accessor;
21-
22-
protected string IncludeString
23-
{
24-
get
25-
{
26-
var ownIncludeString = Accessor != null ? Accessor.Name : null;
27-
return Parent != null && Parent.IncludeString != null
28-
? Parent.IncludeString + "." + ownIncludeString
29-
: ownIncludeString;
30-
}
31-
}
32-
33-
public GraphNode()
34-
{
35-
Members = new Stack<GraphNode>();
36-
}
37-
38-
protected GraphNode(GraphNode parent, PropertyInfo accessor)
20+
public GraphDiffer(GraphNode root)
3921
{
40-
Accessor = accessor;
41-
Members = new Stack<GraphNode>();
42-
Parent = parent;
22+
_root = root;
4323
}
4424

45-
#endregion
46-
47-
// overridden by different implementations
48-
public virtual void Update<T>(DbContext context, T persisted, T updating) where T : class, new()
25+
public T Merge(DbContext context, T updating)
4926
{
50-
UpdateValuesWithConcurrencyCheck(context, updating, persisted);
51-
52-
// Foreach branch perform recursive update
53-
foreach (var member in Members)
27+
bool isAutoDetectEnabled = context.Configuration.AutoDetectChangesEnabled;
28+
try
5429
{
55-
member.Update(context, persisted, updating);
56-
}
57-
}
30+
// performance improvement for large graphs
31+
context.Configuration.AutoDetectChangesEnabled = false;
5832

59-
protected T GetValue<T>(object instance)
60-
{
61-
return (T)Accessor.GetValue(instance, null);
62-
}
63-
64-
protected void SetValue(object instance, object value)
65-
{
66-
Accessor.SetValue(instance, value, null);
67-
}
33+
// Get our entity with all includes needed, or add
34+
T persisted = GetOrAddPersistedEntity(context, updating);
6835

69-
protected static EntityKey CreateEntityKey(IObjectContextAdapter context, object entity)
70-
{
71-
if (entity == null)
72-
{
73-
throw new ArgumentNullException("entity");
74-
}
36+
if (context.Entry(updating).State != EntityState.Detached)
37+
{
38+
throw new InvalidOperationException("GraphDiff supports detached entities only at this time. Please try AsNoTracking() or detach your entites before calling the UpdateGraph method");
39+
}
7540

76-
return context.ObjectContext.CreateEntityKey(context.GetEntitySetName(entity.GetType()), entity);
77-
}
41+
// Perform recursive update
42+
_root.Update(context, persisted, updating);
7843

79-
internal void GetIncludeStrings(DbContext context, List<string> includeStrings)
80-
{
81-
var ownIncludeString = IncludeString;
82-
if (!string.IsNullOrEmpty(ownIncludeString))
83-
{
84-
includeStrings.Add(ownIncludeString);
44+
return persisted;
8545
}
86-
87-
includeStrings.AddRange(GetRequiredNavigationPropertyIncludes(context));
88-
89-
foreach (var member in Members)
46+
finally
9047
{
91-
member.GetIncludeStrings(context, includeStrings);
48+
context.Configuration.AutoDetectChangesEnabled = isAutoDetectEnabled;
9249
}
9350
}
9451

95-
protected virtual IEnumerable<string> GetRequiredNavigationPropertyIncludes(DbContext context)
96-
{
97-
return new string[0];
98-
}
99-
100-
protected static IEnumerable<string> GetRequiredNavigationPropertyIncludes(DbContext context, Type entityType, string ownIncludeString)
101-
{
102-
return context.GetRequiredNavigationPropertiesForType(entityType)
103-
.Select(navigationProperty => ownIncludeString + "." + navigationProperty.Name);
104-
}
105-
106-
protected static void AttachCyclicNavigationProperty(IObjectContextAdapter context, object parent, object child)
52+
private T GetOrAddPersistedEntity(DbContext context, T entity)
10753
{
108-
if (parent == null || child == null) return;
109-
110-
var parentType = ObjectContext.GetObjectType(parent.GetType());
111-
var childType = ObjectContext.GetObjectType(child.GetType());
112-
113-
var navigationProperties = context.GetNavigationPropertiesForType(childType);
114-
115-
var parentNavigationProperty = navigationProperties
116-
.Where(navigation => navigation.TypeUsage.EdmType.Name == parentType.Name)
117-
.Select(navigation => childType.GetProperty(navigation.Name, BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public))
118-
.FirstOrDefault();
119-
120-
if (parentNavigationProperty != null)
121-
parentNavigationProperty.SetValue(child, parent, null);
122-
}
123-
124-
protected static void UpdateValuesWithConcurrencyCheck<T>(DbContext context, T from, T to) where T : class
125-
{
126-
if (context.Entry(to).State != EntityState.Added)
54+
if (entity == null)
12755
{
128-
EnsureConcurrency(context, from, to);
56+
throw new ArgumentNullException("entity");
12957
}
13058

131-
context.Entry(to).CurrentValues.SetValues(from);
132-
}
133-
134-
protected static object AttachAndReloadAssociatedEntity(DbContext context, object entity)
135-
{
136-
var localCopy = FindLocalByKey(context, entity);
137-
if (localCopy != null) return localCopy;
59+
var persisted = FindEntityMatching(context, entity);
13860

139-
if (context.Entry(entity).State == EntityState.Detached)
61+
if (persisted == null)
14062
{
141-
var entityType = ObjectContext.GetObjectType(entity.GetType());
142-
var instance = CreateEmptyEntityWithKey(context, entity);
143-
144-
context.Set(entityType).Attach(instance);
145-
context.Entry(instance).Reload();
146-
147-
AttachRequiredNavigationProperties(context, entity, instance);
148-
return instance;
63+
// we are always working with 2 graphs, simply add a 'persisted' one if none exists,
64+
// this ensures that only the changes we make within the bounds of the mapping are attempted.
65+
persisted = new T();
66+
context.Set<T>().Add(persisted);
14967
}
15068

151-
if (GraphDiffConfiguration.ReloadAssociatedEntitiesWhenAttached)
152-
{
153-
context.Entry(entity).Reload();
154-
}
155-
156-
return entity;
157-
}
158-
159-
private static object FindLocalByKey(DbContext context, object entity)
160-
{
161-
var eType = ObjectContext.GetObjectType(entity.GetType());
162-
return context.Set(eType).Local.OfType<object>().FirstOrDefault(local => IsKeyIdentical(context, local, entity));
69+
return persisted;
16370
}
16471

165-
protected static void AttachRequiredNavigationProperties(DbContext context, object updating, object persisted)
72+
private T FindEntityMatching(DbContext context, T entity)
16673
{
167-
var entityType = ObjectContext.GetObjectType(updating.GetType());
168-
foreach (var navigationProperty in context.GetRequiredNavigationPropertiesForType(updating.GetType()))
169-
{
170-
var navigationPropertyInfo = entityType.GetProperty(navigationProperty.Name, BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public);
74+
var includeStrings = new List<string>();
75+
_root.GetIncludeStrings(context, includeStrings);
17176

172-
var associatedEntity = navigationPropertyInfo.GetValue(updating, null);
173-
if (associatedEntity != null)
174-
{
175-
associatedEntity = FindEntityByKey(context, associatedEntity);
176-
}
177-
178-
navigationPropertyInfo.SetValue(persisted, associatedEntity, null);
179-
}
180-
}
181-
182-
private static object FindEntityByKey(DbContext context, object associatedEntity)
183-
{
184-
var associatedEntityType = ObjectContext.GetObjectType(associatedEntity.GetType());
185-
var keyFields = context.GetPrimaryKeyFieldsFor(associatedEntityType);
186-
var keys = keyFields.Select(key => key.GetValue(associatedEntity, null)).ToArray();
187-
return context.Set(associatedEntityType).Find(keys);
188-
}
77+
// attach includes to IQueryable
78+
var query = context.Set<T>().AsQueryable();
79+
query = includeStrings.Aggregate(query, (current, include) => current.Include(include));
18980

190-
protected static object CreateEmptyEntityWithKey(IObjectContextAdapter context, object entity)
191-
{
192-
var instance = Activator.CreateInstance(entity.GetType());
193-
CopyPrimaryKeyFields(context, entity, instance);
194-
return instance;
81+
// Run the find operation
82+
return query.SingleOrDefault(CreateKeyPredicateExpression(context, entity));
19583
}
19684

197-
private static void CopyPrimaryKeyFields(IObjectContextAdapter context, object from, object to)
85+
private static Expression<Func<T, bool>> CreateKeyPredicateExpression(IObjectContextAdapter context, T entity)
19886
{
199-
var keyProperties = context.GetPrimaryKeyFieldsFor(from.GetType()).ToList();
200-
foreach (var keyProperty in keyProperties)
201-
keyProperty.SetValue(to, keyProperty.GetValue(from, null), null);
202-
}
87+
// get key properties of T
88+
var keyProperties = context.GetPrimaryKeyFieldsFor(typeof(T)).ToList();
20389

204-
protected static bool IsKeyIdentical(DbContext context, object newValue, object dbValue)
205-
{
206-
if (newValue == null || dbValue == null) return false;
90+
ParameterExpression parameter = Expression.Parameter(typeof(T));
91+
Expression expression = CreateEqualsExpression(entity, keyProperties[0], parameter);
92+
for (int i = 1; i < keyProperties.Count; i++)
93+
expression = Expression.And(expression, CreateEqualsExpression(entity, keyProperties[i], parameter));
20794

208-
return CreateEntityKey(context, newValue) == CreateEntityKey(context, dbValue);
95+
return Expression.Lambda<Func<T, bool>>(expression, parameter);
20996
}
21097

211-
private static void EnsureConcurrency<T>(IObjectContextAdapter db, T entity1, T entity2)
98+
private static Expression CreateEqualsExpression(object entity, PropertyInfo keyProperty, Expression parameter)
21299
{
213-
// get concurrency properties of T
214-
var entityType = ObjectContext.GetObjectType(entity1.GetType());
215-
var metadata = db.ObjectContext.MetadataWorkspace;
216-
217-
var objType = metadata.GetEntityTypeByType(entityType);
218-
219-
// need internal string, code smells bad.. any better way to do this?
220-
var cTypeName = (string)objType.GetType()
221-
.GetProperty("CSpaceTypeName", BindingFlags.Instance | BindingFlags.NonPublic)
222-
.GetValue(objType, null);
223-
224-
var conceptualType = metadata.GetItems<EntityType>(DataSpace.CSpace).Single(p => p.FullName == cTypeName);
225-
var concurrencyProperties = conceptualType.Members
226-
.Where(member => member.TypeUsage.Facets.Any(facet => facet.Name == "ConcurrencyMode" && (ConcurrencyMode)facet.Value == ConcurrencyMode.Fixed))
227-
.Select(member => entityType.GetProperty(member.Name, BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public))
228-
.ToList();
229-
230-
// Check if concurrency properties are equal
231-
// TODO EF should do this automatically should it not?
232-
foreach (var concurrencyProp in concurrencyProperties)
233-
{
234-
// if is byte[] use array comparison, else equals().
235-
236-
var type = concurrencyProp.PropertyType;
237-
var obj1 = concurrencyProp.GetValue(entity1, null);
238-
var obj2 = concurrencyProp.GetValue(entity2, null);
239-
240-
if (
241-
(obj1 == null || obj2 == null) ||
242-
(type == typeof (byte[]) && !((byte[]) obj1).SequenceEqual((byte[]) obj2)) ||
243-
(type != typeof (byte[]) && !obj1.Equals(obj2))
244-
)
245-
{
246-
throw new DbUpdateConcurrencyException(String.Format("{0} failed optimistic concurrency", concurrencyProp.Name));
247-
}
248-
}
100+
return Expression.Equal(Expression.Property(parameter, keyProperty), Expression.Constant(keyProperty.GetValue(entity, null), keyProperty.PropertyType));
249101
}
250102
}
251103
}

0 commit comments

Comments
 (0)