1- using RefactorThis . GraphDiff . Internal . Graph ;
21using System ;
32using System . Collections . Generic ;
43using System . Data . Entity ;
4+ using System . Data . Entity . Core ;
5+ using System . Data . Entity . Core . Metadata . Edm ;
6+ using System . Data . Entity . Core . Objects ;
57using System . Data . Entity . Infrastructure ;
68using System . Linq ;
7- using System . Linq . Expressions ;
89using System . Reflection ;
910
10- namespace RefactorThis . GraphDiff . Internal
11+ namespace RefactorThis . GraphDiff . Internal . Graph
1112{
12- /// <summary>
13- /// GraphDiff main access point.
14- /// </summary>
15- /// <typeparam name="T">The root agreggate type</typeparam>
16- internal class GraphDiffer < T > where T : class , new ( )
13+ internal class GraphNode
1714 {
18- private readonly GraphNode _root ;
15+ #region Fields, Properties Constructors
1916
20- public GraphDiffer ( GraphNode root )
17+ public GraphNode Parent { get ; private set ; }
18+ public Stack < GraphNode > Members { get ; private set ; }
19+
20+ protected readonly PropertyInfo Accessor ;
21+
22+ protected string IncludeString
2123 {
22- _root = root ;
24+ get
25+ {
26+ var ownIncludeString = Accessor != null ? Accessor . Name : null ;
27+ return Parent != null && Parent . IncludeString != null
28+ ? Parent . IncludeString + "." + ownIncludeString
29+ : ownIncludeString ;
30+ }
2331 }
2432
25- public T Merge ( DbContext context , T updating )
33+ public GraphNode ( )
2634 {
27- bool isAutoDetectEnabled = context . Configuration . AutoDetectChangesEnabled ;
28- try
29- {
30- // performance improvement for large graphs
31- context . Configuration . AutoDetectChangesEnabled = false ;
35+ Members = new Stack < GraphNode > ( ) ;
36+ }
3237
33- // Get our entity with all includes needed, or add
34- T persisted = GetOrAddPersistedEntity ( context , updating ) ;
38+ protected GraphNode ( GraphNode parent , PropertyInfo accessor )
39+ {
40+ Accessor = accessor ;
41+ Members = new Stack < GraphNode > ( ) ;
42+ Parent = parent ;
43+ }
3544
36- if ( context . Entry ( updating ) . State != EntityState . Detached )
37- {
38- throw new InvalidOperationException ( "GraphDiff supports detached entities only at this time. Please try AsNoTracking() or detach your entites before calling the UpdateGraph method" ) ;
39- }
45+ #endregion
4046
41- // Perform recursive update
42- _root . Update ( context , persisted , updating ) ;
47+ // overridden by different implementations
48+ public virtual void Update < T > ( DbContext context , T persisted , T updating ) where T : class , new ( )
49+ {
50+ UpdateValuesWithConcurrencyCheck ( context , updating , persisted ) ;
4351
44- return persisted ;
45- }
46- finally
52+ // Foreach branch perform recursive update
53+ foreach ( var member in Members )
4754 {
48- context . Configuration . AutoDetectChangesEnabled = isAutoDetectEnabled ;
55+ member . Update ( context , persisted , updating ) ;
4956 }
5057 }
5158
52- private T GetOrAddPersistedEntity ( DbContext context , T entity )
59+ protected T GetValue < T > ( object instance )
60+ {
61+ return ( T ) Accessor . GetValue ( instance , null ) ;
62+ }
63+
64+ protected void SetValue ( object instance , object value )
65+ {
66+ Accessor . SetValue ( instance , value , null ) ;
67+ }
68+
69+ protected static EntityKey CreateEntityKey ( IObjectContextAdapter context , object entity )
5370 {
5471 if ( entity == null )
5572 {
5673 throw new ArgumentNullException ( "entity" ) ;
5774 }
5875
59- var persisted = FindEntityMatching ( context , entity ) ;
76+ return context . ObjectContext . CreateEntityKey ( context . GetEntitySetName ( entity . GetType ( ) ) , entity ) ;
77+ }
78+
79+ internal void GetIncludeStrings ( DbContext context , List < string > includeStrings )
80+ {
81+ var ownIncludeString = IncludeString ;
82+ if ( ! string . IsNullOrEmpty ( ownIncludeString ) )
83+ {
84+ includeStrings . Add ( ownIncludeString ) ;
85+ }
86+
87+ includeStrings . AddRange ( GetRequiredNavigationPropertyIncludes ( context ) ) ;
88+
89+ foreach ( var member in Members )
90+ {
91+ member . GetIncludeStrings ( context , includeStrings ) ;
92+ }
93+ }
94+
95+ protected virtual IEnumerable < string > GetRequiredNavigationPropertyIncludes ( DbContext context )
96+ {
97+ return new string [ 0 ] ;
98+ }
99+
100+ protected static IEnumerable < string > GetRequiredNavigationPropertyIncludes ( DbContext context , Type entityType , string ownIncludeString )
101+ {
102+ return context . GetRequiredNavigationPropertiesForType ( entityType )
103+ . Select ( navigationProperty => ownIncludeString + "." + navigationProperty . Name ) ;
104+ }
105+
106+ protected static void AttachCyclicNavigationProperty ( IObjectContextAdapter context , object parent , object child )
107+ {
108+ if ( parent == null || child == null ) return ;
109+
110+ var parentType = ObjectContext . GetObjectType ( parent . GetType ( ) ) ;
111+ var childType = ObjectContext . GetObjectType ( child . GetType ( ) ) ;
112+
113+ var navigationProperties = context . GetNavigationPropertiesForType ( childType ) ;
114+
115+ var parentNavigationProperty = navigationProperties
116+ . Where ( navigation => navigation . TypeUsage . EdmType . Name == parentType . Name )
117+ . Select ( navigation => childType . GetProperty ( navigation . Name , BindingFlags . Instance | BindingFlags . NonPublic | BindingFlags . Public ) )
118+ . FirstOrDefault ( ) ;
119+
120+ if ( parentNavigationProperty != null )
121+ parentNavigationProperty . SetValue ( child , parent , null ) ;
122+ }
123+
124+ protected static void UpdateValuesWithConcurrencyCheck < T > ( DbContext context , T from , T to ) where T : class
125+ {
126+ if ( context . Entry ( to ) . State != EntityState . Added )
127+ {
128+ EnsureConcurrency ( context , from , to ) ;
129+ }
130+
131+ context . Entry ( to ) . CurrentValues . SetValues ( from ) ;
132+ }
133+
134+ protected static object AttachAndReloadAssociatedEntity ( DbContext context , object entity )
135+ {
136+ var localCopy = FindLocalByKey ( context , entity ) ;
137+ if ( localCopy != null ) return localCopy ;
60138
61- if ( persisted == null )
139+ if ( context . Entry ( entity ) . State == EntityState . Detached )
62140 {
63- // we are always working with 2 graphs, simply add a 'persisted' one if none exists,
64- // this ensures that only the changes we make within the bounds of the mapping are attempted.
65- persisted = new T ( ) ;
66- context . Set < T > ( ) . Add ( persisted ) ;
141+ var entityType = ObjectContext . GetObjectType ( entity . GetType ( ) ) ;
142+ var instance = CreateEmptyEntityWithKey ( context , entity ) ;
143+
144+ context . Set ( entityType ) . Attach ( instance ) ;
145+ context . Entry ( instance ) . Reload ( ) ;
146+
147+ AttachRequiredNavigationProperties ( context , entity , instance ) ;
148+ return instance ;
67149 }
68150
69- return persisted ;
151+ if ( GraphDiffConfiguration . ReloadAssociatedEntitiesWhenAttached )
152+ {
153+ context . Entry ( entity ) . Reload ( ) ;
154+ }
155+
156+ return entity ;
157+ }
158+
159+ private static object FindLocalByKey ( DbContext context , object entity )
160+ {
161+ var eType = ObjectContext . GetObjectType ( entity . GetType ( ) ) ;
162+ return context . Set ( eType ) . Local . OfType < object > ( ) . FirstOrDefault ( local => IsKeyIdentical ( context , local , entity ) ) ;
70163 }
71164
72- private T FindEntityMatching ( DbContext context , T entity )
165+ protected static void AttachRequiredNavigationProperties ( DbContext context , object updating , object persisted )
73166 {
74- var includeStrings = new List < string > ( ) ;
75- _root . GetIncludeStrings ( context , includeStrings ) ;
167+ var entityType = ObjectContext . GetObjectType ( updating . GetType ( ) ) ;
168+ foreach ( var navigationProperty in context . GetRequiredNavigationPropertiesForType ( updating . GetType ( ) ) )
169+ {
170+ var navigationPropertyInfo = entityType . GetProperty ( navigationProperty . Name , BindingFlags . Instance | BindingFlags . NonPublic | BindingFlags . Public ) ;
76171
77- // attach includes to IQueryable
78- var query = context . Set < T > ( ) . AsQueryable ( ) ;
79- query = includeStrings . Aggregate ( query , ( current , include ) => current . Include ( include ) ) ;
172+ var associatedEntity = navigationPropertyInfo . GetValue ( updating , null ) ;
173+ if ( associatedEntity != null )
174+ {
175+ associatedEntity = FindEntityByKey ( context , associatedEntity ) ;
176+ }
177+
178+ navigationPropertyInfo . SetValue ( persisted , associatedEntity , null ) ;
179+ }
180+ }
181+
182+ private static object FindEntityByKey ( DbContext context , object associatedEntity )
183+ {
184+ var associatedEntityType = ObjectContext . GetObjectType ( associatedEntity . GetType ( ) ) ;
185+ var keyFields = context . GetPrimaryKeyFieldsFor ( associatedEntityType ) ;
186+ var keys = keyFields . Select ( key => key . GetValue ( associatedEntity , null ) ) . ToArray ( ) ;
187+ return context . Set ( associatedEntityType ) . Find ( keys ) ;
188+ }
80189
81- // Run the find operation
82- return query . SingleOrDefault ( CreateKeyPredicateExpression ( context , entity ) ) ;
190+ protected static object CreateEmptyEntityWithKey ( IObjectContextAdapter context , object entity )
191+ {
192+ var instance = Activator . CreateInstance ( entity . GetType ( ) ) ;
193+ CopyPrimaryKeyFields ( context , entity , instance ) ;
194+ return instance ;
83195 }
84196
85- private static Expression < Func < T , bool > > CreateKeyPredicateExpression ( IObjectContextAdapter context , T entity )
197+ private static void CopyPrimaryKeyFields ( IObjectContextAdapter context , object from , object to )
86198 {
87- // get key properties of T
88- var keyProperties = context . GetPrimaryKeyFieldsFor ( typeof ( T ) ) . ToList ( ) ;
199+ var keyProperties = context . GetPrimaryKeyFieldsFor ( from . GetType ( ) ) . ToList ( ) ;
200+ foreach ( var keyProperty in keyProperties )
201+ keyProperty . SetValue ( to , keyProperty . GetValue ( from , null ) , null ) ;
202+ }
89203
90- ParameterExpression parameter = Expression . Parameter ( typeof ( T ) ) ;
91- Expression expression = CreateEqualsExpression ( entity , keyProperties [ 0 ] , parameter ) ;
92- for ( int i = 1 ; i < keyProperties . Count ; i ++ )
93- expression = Expression . And ( expression , CreateEqualsExpression ( entity , keyProperties [ i ] , parameter ) ) ;
204+ protected static bool IsKeyIdentical ( DbContext context , object newValue , object dbValue )
205+ {
206+ if ( newValue == null || dbValue == null ) return false ;
94207
95- return Expression . Lambda < Func < T , bool > > ( expression , parameter ) ;
208+ return CreateEntityKey ( context , newValue ) == CreateEntityKey ( context , dbValue ) ;
96209 }
97210
98- private static Expression CreateEqualsExpression ( object entity , PropertyInfo keyProperty , Expression parameter )
211+ private static void EnsureConcurrency < T > ( IObjectContextAdapter db , T entity1 , T entity2 )
99212 {
100- return Expression . Equal ( Expression . Property ( parameter , keyProperty ) , Expression . Constant ( keyProperty . GetValue ( entity , null ) , keyProperty . PropertyType ) ) ;
213+ // get concurrency properties of T
214+ var entityType = ObjectContext . GetObjectType ( entity1 . GetType ( ) ) ;
215+ var metadata = db . ObjectContext . MetadataWorkspace ;
216+
217+ var objType = metadata . GetEntityTypeByType ( entityType ) ;
218+
219+ // need internal string, code smells bad.. any better way to do this?
220+ var cTypeName = ( string ) objType . GetType ( )
221+ . GetProperty ( "CSpaceTypeName" , BindingFlags . Instance | BindingFlags . NonPublic )
222+ . GetValue ( objType , null ) ;
223+
224+ var conceptualType = metadata . GetItems < EntityType > ( DataSpace . CSpace ) . Single ( p => p . FullName == cTypeName ) ;
225+ var concurrencyProperties = conceptualType . Members
226+ . Where ( member => member . TypeUsage . Facets . Any ( facet => facet . Name == "ConcurrencyMode" && ( ConcurrencyMode ) facet . Value == ConcurrencyMode . Fixed ) )
227+ . Select ( member => entityType . GetProperty ( member . Name , BindingFlags . Instance | BindingFlags . NonPublic | BindingFlags . Public ) )
228+ . ToList ( ) ;
229+
230+ // Check if concurrency properties are equal
231+ // TODO EF should do this automatically should it not?
232+ foreach ( var concurrencyProp in concurrencyProperties )
233+ {
234+ // if is byte[] use array comparison, else equals().
235+
236+ var type = concurrencyProp . PropertyType ;
237+ var obj1 = concurrencyProp . GetValue ( entity1 , null ) ;
238+ var obj2 = concurrencyProp . GetValue ( entity2 , null ) ;
239+
240+ if (
241+ ( obj1 == null || obj2 == null ) ||
242+ ( type == typeof ( byte [ ] ) && ! ( ( byte [ ] ) obj1 ) . SequenceEqual ( ( byte [ ] ) obj2 ) ) ||
243+ ( type != typeof ( byte [ ] ) && ! obj1 . Equals ( obj2 ) )
244+ )
245+ {
246+ throw new DbUpdateConcurrencyException ( String . Format ( "{0} failed optimistic concurrency" , concurrencyProp . Name ) ) ;
247+ }
248+ }
101249 }
102250 }
103- }
251+ }
0 commit comments