1   /*
2    * SPDX-FileCopyrightText: none
3    * SPDX-License-Identifier: CC0-1.0
4    */
5   
6   package gov.nist.secauto.oscal.lib.profile.resolver.alter;
7   
8   import gov.nist.secauto.metaschema.core.util.CollectionUtil;
9   import gov.nist.secauto.metaschema.core.util.ObjectUtils;
10  import gov.nist.secauto.oscal.lib.model.Catalog;
11  import gov.nist.secauto.oscal.lib.model.CatalogGroup;
12  import gov.nist.secauto.oscal.lib.model.Control;
13  import gov.nist.secauto.oscal.lib.model.ControlPart;
14  import gov.nist.secauto.oscal.lib.model.Link;
15  import gov.nist.secauto.oscal.lib.model.Parameter;
16  import gov.nist.secauto.oscal.lib.model.Property;
17  import gov.nist.secauto.oscal.lib.model.control.catalog.ICatalogVisitor;
18  import gov.nist.secauto.oscal.lib.model.metadata.IProperty;
19  import gov.nist.secauto.oscal.lib.profile.resolver.ProfileResolutionEvaluationException;
20  
21  import java.util.Collection;
22  import java.util.Collections;
23  import java.util.EnumMap;
24  import java.util.EnumSet;
25  import java.util.Iterator;
26  import java.util.List;
27  import java.util.Locale;
28  import java.util.Map;
29  import java.util.Set;
30  import java.util.concurrent.ConcurrentHashMap;
31  import java.util.function.Function;
32  import java.util.function.Supplier;
33  
34  import edu.umd.cs.findbugs.annotations.NonNull;
35  import edu.umd.cs.findbugs.annotations.Nullable;
36  
37  public class RemoveVisitor implements ICatalogVisitor<Boolean, RemoveVisitor.Context> {
38    public enum TargetType {
39      PARAM("param", Parameter.class),
40      PROP("prop", Property.class),
41      LINK("link", Link.class),
42      PART("part", ControlPart.class);
43  
44      @NonNull
45      private static final Map<Class<?>, TargetType> CLASS_TO_TYPE;
46      @NonNull
47      private static final Map<String, TargetType> NAME_TO_TYPE;
48      @NonNull
49      private final String fieldName;
50      @NonNull
51      private final Class<?> clazz;
52  
53      static {
54        {
55          Map<Class<?>, TargetType> map = new ConcurrentHashMap<>();
56          for (TargetType type : values()) {
57            map.put(type.getClazz(), type);
58          }
59          CLASS_TO_TYPE = CollectionUtil.unmodifiableMap(map);
60        }
61  
62        {
63          Map<String, TargetType> map = new ConcurrentHashMap<>();
64          for (TargetType type : values()) {
65            map.put(type.fieldName(), type);
66          }
67          NAME_TO_TYPE = CollectionUtil.unmodifiableMap(map);
68        }
69      }
70  
71      /**
72       * Get the target type associated with the provided {@code clazz}.
73       *
74       * @param clazz
75       *          the class to identify the target type for
76       * @return the associated target type or {@code null} if the class is not
77       *         associated with a target type
78       */
79      @Nullable
80      public static TargetType forClass(@NonNull Class<?> clazz) {
81        Class<?> target = clazz;
82        TargetType retval;
83        // recurse over parent classes to find a match
84        do {
85          retval = CLASS_TO_TYPE.get(target);
86        } while (retval == null && (target = target.getSuperclass()) != null);
87        return retval;
88      }
89  
90      /**
91       * Get the target type associated with the provided field {@code name}.
92       *
93       * @param name
94       *          the field name to identify the target type for
95       * @return the associated target type or {@code null} if the name is not
96       *         associated with a target type
97       */
98      @Nullable
99      public static TargetType forFieldName(@Nullable String name) {
100       return name == null ? null : NAME_TO_TYPE.get(name);
101     }
102 
103     TargetType(@NonNull String fieldName, @NonNull Class<?> clazz) {
104       this.fieldName = fieldName;
105       this.clazz = clazz;
106     }
107 
108     /**
109      * Get the field name associated with the target type.
110      *
111      * @return the name
112      */
113     public String fieldName() {
114       return fieldName;
115     }
116 
117     /**
118      * Get the bound class associated with the target type.
119      *
120      * @return the class
121      */
122     public Class<?> getClazz() {
123       return clazz;
124     }
125 
126   }
127 
128   @NonNull
129   private static final RemoveVisitor INSTANCE = new RemoveVisitor();
130 
131   private static final Map<TargetType, Set<TargetType>> APPLICABLE_TARGETS;
132 
133   static {
134     APPLICABLE_TARGETS = new EnumMap<>(TargetType.class);
135     APPLICABLE_TARGETS.put(TargetType.PARAM, Set.of(TargetType.PROP, TargetType.LINK));
136     APPLICABLE_TARGETS.put(TargetType.PART, Set.of(TargetType.PART, TargetType.PROP, TargetType.LINK));
137   }
138 
139   private static Set<TargetType> getApplicableTypes(@NonNull TargetType type) {
140     return APPLICABLE_TARGETS.getOrDefault(type, CollectionUtil.emptySet());
141   }
142 
143   private static <T> boolean handle(
144       @NonNull TargetType itemType,
145       @NonNull Supplier<? extends Collection<T>> supplier,
146       @Nullable Function<T, Boolean> handler,
147       @NonNull Context context) {
148 
149     boolean handleChildren = !Collections.disjoint(context.getTargetItemTypes(), getApplicableTypes(itemType));
150     boolean retval = false;
151     if (context.isMatchingType(itemType)) {
152       // if the item type is applicable, attempt to remove any items
153       Iterator<T> iter = supplier.get().iterator();
154       while (iter.hasNext()) {
155         T item = iter.next();
156 
157         if (item == null || context.isApplicableTo(item)) {
158           iter.remove();
159           retval = true;
160           // ignore removed items and their children
161         } else if (handler != null && handleChildren) {
162           // handle child items since they are applicable to the search criteria
163           retval = retval || handler.apply(item);
164         }
165       }
166     } else if (handleChildren && handler != null) {
167       for (T item : supplier.get()) {
168         if (item != null) {
169           retval = retval || handler.apply(item);
170         }
171       }
172     }
173     return retval;
174   }
175 
176   /**
177    * Apply the remove directive.
178    *
179    * @param control
180    *          the control target
181    * @param objectName
182    *          the name flag of a matching node to remove
183    * @param objectClass
184    *          the class flag of a matching node to remove
185    * @param objectId
186    *          the id flag of a matching node to remove
187    * @param objectNamespace
188    *          the namespace flag of a matching node to remove
189    * @param itemType
190    *          the type of a matching node to remove
191    * @return {@code true} if the modification was made or {@code false} otherwise
192    * @throws ProfileResolutionEvaluationException
193    *           if a processing error occurred during profile resolution
194    */
195   public static boolean remove(
196       @NonNull Control control,
197       @Nullable String objectName,
198       @Nullable String objectClass,
199       @Nullable String objectId,
200       @Nullable String objectNamespace,
201       @Nullable TargetType itemType) {
202     return INSTANCE.visitControl(
203         control,
204         new Context(objectName, objectClass, objectId, objectNamespace, itemType));
205   }
206 
207   @Override
208   public Boolean visitCatalog(Catalog catalog, Context context) {
209     // not required
210     throw new UnsupportedOperationException("not needed");
211   }
212 
213   @Override
214   public Boolean visitGroup(CatalogGroup group, Context context) {
215     // not required
216     throw new UnsupportedOperationException("not needed");
217   }
218 
219   @NonNull
220   private static <T> List<T> modifiableListOrEmpty(@Nullable List<T> list) {
221     return list == null ? CollectionUtil.emptyList() : list;
222   }
223 
224   @Override
225   public Boolean visitControl(Control control, Context context) {
226     assert context != null;
227 
228     // visit params
229     boolean retval = handle(
230         TargetType.PARAM,
231         () -> modifiableListOrEmpty(control.getParams()),
232         child -> visitParameter(ObjectUtils.notNull(child), context),
233         context);
234 
235     // visit props
236     retval = retval || handle(
237         TargetType.PROP,
238         () -> modifiableListOrEmpty(control.getProps()),
239         null,
240         context);
241 
242     // visit links
243     retval = retval || handle(
244         TargetType.LINK,
245         () -> modifiableListOrEmpty(control.getLinks()),
246         null,
247         context);
248 
249     return retval || handle(
250         TargetType.PART,
251         () -> modifiableListOrEmpty(control.getParts()),
252         child -> visitPart(child, context),
253         context);
254   }
255 
256   @Override
257   public Boolean visitParameter(Parameter parameter, Context context) {
258     assert context != null;
259 
260     // visit props
261     boolean retval = handle(
262         TargetType.PROP,
263         () -> modifiableListOrEmpty(parameter.getProps()),
264         null,
265         context);
266 
267     return retval || handle(
268         TargetType.LINK,
269         () -> modifiableListOrEmpty(parameter.getLinks()),
270         null,
271         context);
272   }
273 
274   /**
275    * Visit the control part.
276    *
277    * @param part
278    *          the bound part object
279    * @param context
280    *          the visitor context
281    * @return {@code true} if the removal was applied or {@code false} otherwise
282    */
283   public boolean visitPart(ControlPart part, Context context) {
284     assert context != null;
285 
286     // visit props
287     boolean retval = handle(
288         TargetType.PROP,
289         () -> modifiableListOrEmpty(part.getProps()),
290         null,
291         context);
292 
293     // visit links
294     retval = retval || handle(
295         TargetType.LINK,
296         () -> modifiableListOrEmpty(part.getLinks()),
297         null,
298         context);
299 
300     return retval || handle(
301         TargetType.PART,
302         () -> modifiableListOrEmpty(part.getParts()),
303         child -> visitPart(child, context),
304         context);
305   }
306 
307   static final class Context {
308     /**
309      * Types with an "name" flag.
310      */
311     @NonNull
312     private static final Set<TargetType> NAME_TYPES = ObjectUtils.notNull(
313         Set.of(TargetType.PART, TargetType.PROP));
314     /**
315      * Types with an "class" flag.
316      */
317     @NonNull
318     private static final Set<TargetType> CLASS_TYPES = ObjectUtils.notNull(
319         Set.of(TargetType.PARAM, TargetType.PART, TargetType.PROP));
320     /**
321      * Types with an "id" flag.
322      */
323     @NonNull
324     private static final Set<TargetType> ID_TYPES = ObjectUtils.notNull(
325         Set.of(TargetType.PARAM, TargetType.PART));
326     /**
327      * Types with an "ns" flag.
328      */
329     @NonNull
330     private static final Set<TargetType> NAMESPACE_TYPES = ObjectUtils.notNull(
331         Set.of(TargetType.PART, TargetType.PROP));
332 
333     @Nullable
334     private final String objectName;
335     @Nullable
336     private final String objectClass;
337     @Nullable
338     private final String objectId;
339     @Nullable
340     private final String objectNamespace;
341     @NonNull
342     private final Set<TargetType> targetItemTypes;
343 
344     private static boolean filterTypes(
345         @NonNull Set<TargetType> effectiveTypes,
346         @NonNull String criteria,
347         @NonNull Set<TargetType> allowedTypes,
348         @Nullable String value,
349         @Nullable TargetType itemType) {
350       boolean retval = false;
351       if (value != null) {
352         retval = effectiveTypes.retainAll(allowedTypes);
353         if (itemType != null && !allowedTypes.contains(itemType)) {
354           throw new ProfileResolutionEvaluationException(
355               String.format("%s='%s' is not supported for items of type '%s'",
356                   criteria,
357                   value,
358                   itemType.fieldName()));
359         }
360       }
361       return retval;
362     }
363 
364     private Context(
365         @Nullable String objectName,
366         @Nullable String objectClass,
367         @Nullable String objectId,
368         @Nullable String objectNamespace,
369         @Nullable TargetType itemType) {
370 
371       // determine the set of effective item types to search for
372       // this helps with short-circuit searching for parts of the graph that cannot
373       // match
374       @NonNull Set<TargetType> targetItemTypes = ObjectUtils.notNull(EnumSet.allOf(TargetType.class));
375       filterTypes(targetItemTypes, "by-name", NAME_TYPES, objectName, itemType);
376       filterTypes(targetItemTypes, "by-class", CLASS_TYPES, objectClass, itemType);
377       filterTypes(targetItemTypes, "by-id", ID_TYPES, objectId, itemType);
378       filterTypes(targetItemTypes, "by-ns", NAMESPACE_TYPES, objectNamespace, itemType);
379 
380       if (itemType != null) {
381         targetItemTypes.retainAll(Set.of(itemType));
382       }
383 
384       if (targetItemTypes.isEmpty()) {
385         throw new ProfileResolutionEvaluationException("The filter matches no available item types");
386       }
387 
388       this.objectName = objectName;
389       this.objectClass = objectClass;
390       this.objectId = objectId;
391       this.objectNamespace = objectNamespace;
392       this.targetItemTypes = CollectionUtil.unmodifiableSet(targetItemTypes);
393     }
394 
395     @Nullable
396     public String getObjectName() {
397       return objectName;
398     }
399 
400     @Nullable
401     public String getObjectClass() {
402       return objectClass;
403     }
404 
405     @Nullable
406     public String getObjectId() {
407       return objectId;
408     }
409 
410     @NonNull
411     public Set<TargetType> getTargetItemTypes() {
412       return targetItemTypes;
413     }
414 
415     public boolean isMatchingType(@NonNull TargetType type) {
416       return getTargetItemTypes().contains(type);
417     }
418 
419     @Nullable
420     public String getObjectNamespace() {
421       return objectNamespace;
422     }
423 
424     private static boolean checkValue(@Nullable String actual, @Nullable String expected) {
425       return expected == null || expected.equals(actual);
426     }
427 
428     public boolean isApplicableTo(@NonNull Object obj) {
429       TargetType objectType = TargetType.forClass(obj.getClass());
430 
431       boolean retval = objectType != null && getTargetItemTypes().contains(objectType);
432       if (retval) {
433         assert objectType != null;
434 
435         // check other criteria
436         String actualName = null;
437         String actualClass = null;
438         String actualId = null;
439         String actualNamespace = null;
440 
441         switch (objectType) {
442         case PARAM: {
443           Parameter param = (Parameter) obj;
444           actualClass = param.getClazz();
445           actualId = param.getId();
446           break;
447         }
448         case PROP: {
449           Property prop = (Property) obj;
450           actualName = prop.getName();
451           actualClass = prop.getClazz();
452           actualNamespace = prop.getNs() == null ? IProperty.OSCAL_NAMESPACE.toString() : prop.getNs().toString();
453           break;
454         }
455         case PART: {
456           ControlPart part = (ControlPart) obj;
457           actualName = part.getName();
458           actualClass = part.getClazz();
459           actualId = part.getId() == null ? null : part.getId();
460           actualNamespace = part.getNs() == null ? IProperty.OSCAL_NAMESPACE.toString() : part.getNs().toString();
461           break;
462         }
463         case LINK:
464           // do nothing
465           break;
466         default:
467           throw new UnsupportedOperationException(objectType.name().toLowerCase(Locale.ROOT));
468         }
469 
470         retval = checkValue(actualName, getObjectName())
471             && checkValue(actualClass, getObjectClass())
472             && checkValue(actualId, getObjectId())
473             && checkValue(actualNamespace, getObjectNamespace());
474       }
475       return retval;
476     }
477   }
478 }