1   /*
2    * SPDX-FileCopyrightText: none
3    * SPDX-License-Identifier: CC0-1.0
4    */
5   
6   package gov.nist.secauto.oscal.lib.profile.resolver.alter;
7   
8   import gov.nist.secauto.metaschema.core.datatype.markup.MarkupLine;
9   import gov.nist.secauto.metaschema.core.util.CollectionUtil;
10  import gov.nist.secauto.metaschema.core.util.CustomCollectors;
11  import gov.nist.secauto.metaschema.core.util.ObjectUtils;
12  import gov.nist.secauto.oscal.lib.model.Catalog;
13  import gov.nist.secauto.oscal.lib.model.CatalogGroup;
14  import gov.nist.secauto.oscal.lib.model.Control;
15  import gov.nist.secauto.oscal.lib.model.ControlPart;
16  import gov.nist.secauto.oscal.lib.model.Link;
17  import gov.nist.secauto.oscal.lib.model.Parameter;
18  import gov.nist.secauto.oscal.lib.model.Property;
19  import gov.nist.secauto.oscal.lib.model.control.catalog.ICatalogVisitor;
20  import gov.nist.secauto.oscal.lib.profile.resolver.ProfileResolutionEvaluationException;
21  
22  import java.util.Collections;
23  import java.util.EnumMap;
24  import java.util.EnumSet;
25  import java.util.LinkedList;
26  import java.util.List;
27  import java.util.ListIterator;
28  import java.util.Locale;
29  import java.util.Map;
30  import java.util.Set;
31  import java.util.concurrent.ConcurrentHashMap;
32  import java.util.function.Consumer;
33  import java.util.function.Function;
34  import java.util.function.Supplier;
35  
36  import edu.umd.cs.findbugs.annotations.NonNull;
37  import edu.umd.cs.findbugs.annotations.Nullable;
38  
39  public class AddVisitor implements ICatalogVisitor<Boolean, AddVisitor.Context> {
40    public enum TargetType {
41      CONTROL("control", Control.class),
42      PARAM("param", Parameter.class),
43      PART("part", ControlPart.class);
44  
45      @NonNull
46      private static final Map<Class<?>, TargetType> CLASS_TO_TYPE;
47      @NonNull
48      private static final Map<String, TargetType> NAME_TO_TYPE;
49      @NonNull
50      private final String fieldName;
51      @NonNull
52      private final Class<?> clazz;
53  
54      static {
55        {
56          Map<Class<?>, TargetType> map = new ConcurrentHashMap<>();
57          for (TargetType type : values()) {
58            map.put(type.getClazz(), type);
59          }
60          CLASS_TO_TYPE = CollectionUtil.unmodifiableMap(map);
61        }
62  
63        {
64          Map<String, TargetType> map = new ConcurrentHashMap<>();
65          for (TargetType type : values()) {
66            map.put(type.fieldName(), type);
67          }
68          NAME_TO_TYPE = CollectionUtil.unmodifiableMap(map);
69        }
70      }
71  
72      /**
73       * Get the target type associated with the provided {@code clazz}.
74       *
75       * @param clazz
76       *          the class to identify the target type for
77       * @return the associated target type or {@code null} if the class is not
78       *         associated with a target type
79       */
80      @Nullable
81      public static TargetType forClass(@NonNull Class<?> clazz) {
82        Class<?> target = clazz;
83        TargetType retval;
84        // recurse over parent classes to find a match
85        do {
86          retval = CLASS_TO_TYPE.get(target);
87        } while (retval == null && (target = target.getSuperclass()) != null);
88        return retval;
89      }
90  
91      /**
92       * Get the target type associated with the provided field {@code name}.
93       *
94       * @param name
95       *          the field name to identify the target type for
96       * @return the associated target type or {@code null} if the name is not
97       *         associated with a target type
98       */
99      @Nullable
100     public static TargetType forFieldName(@Nullable String name) {
101       return name == null ? null : NAME_TO_TYPE.get(name);
102     }
103 
104     TargetType(@NonNull String fieldName, @NonNull Class<?> clazz) {
105       this.fieldName = fieldName;
106       this.clazz = clazz;
107     }
108 
109     /**
110      * Get the field name associated with the target type.
111      *
112      * @return the name
113      */
114     public String fieldName() {
115       return fieldName;
116     }
117 
118     /**
119      * Get the bound class associated with the target type.
120      *
121      * @return the class
122      */
123     public Class<?> getClazz() {
124       return clazz;
125     }
126   }
127 
128   public enum Position {
129     BEFORE,
130     AFTER,
131     STARTING,
132     ENDING;
133 
134     @NonNull
135     private static final Map<String, Position> NAME_TO_POSITION;
136 
137     static {
138       Map<String, Position> map = new ConcurrentHashMap<>();
139       for (Position position : values()) {
140         map.put(position.name().toLowerCase(Locale.ROOT), position);
141       }
142       NAME_TO_POSITION = CollectionUtil.unmodifiableMap(map);
143     }
144 
145     /**
146      * Get the position associated with the provided {@code name}.
147      *
148      * @param name
149      *          the name to identify the position for
150      * @return the associated position or {@code null} if the name is not associated
151      *         with a position
152      */
153     @Nullable
154     public static Position forName(@Nullable String name) {
155       return name == null ? null : NAME_TO_POSITION.get(name);
156     }
157   }
158 
159   @NonNull
160   private static final AddVisitor INSTANCE = new AddVisitor();
161   private static final Map<TargetType, Set<TargetType>> APPLICABLE_TARGETS;
162 
163   static {
164     APPLICABLE_TARGETS = new EnumMap<>(TargetType.class);
165     APPLICABLE_TARGETS.put(TargetType.CONTROL, Set.of(TargetType.CONTROL, TargetType.PARAM, TargetType.PART));
166     APPLICABLE_TARGETS.put(TargetType.PARAM, Set.of(TargetType.PARAM));
167     APPLICABLE_TARGETS.put(TargetType.PART, Set.of(TargetType.PART));
168   }
169 
170   private static Set<TargetType> getApplicableTypes(@NonNull TargetType type) {
171     return APPLICABLE_TARGETS.getOrDefault(type, CollectionUtil.emptySet());
172   }
173 
174   /**
175    * Apply the add directive.
176    *
177    * @param control
178    *          the control target
179    * @param position
180    *          the position to apply the content or {@code null}
181    * @param byId
182    *          the identifier of the target or {@code null}
183    * @param title
184    *          a title to set
185    * @param params
186    *          parameters to add
187    * @param props
188    *          properties to add
189    * @param links
190    *          links to add
191    * @param parts
192    *          parts to add
193    * @return {@code true} if the modification was made or {@code false} otherwise
194    * @throws ProfileResolutionEvaluationException
195    *           if a processing error occurred during profile resolution
196    */
197   public static boolean add(
198       @NonNull Control control,
199       @Nullable Position position,
200       @Nullable String byId,
201       @Nullable MarkupLine title,
202       @NonNull List<Parameter> params,
203       @NonNull List<Property> props,
204       @NonNull List<Link> links,
205       @NonNull List<ControlPart> parts) {
206     return INSTANCE.visitControl(
207         control,
208         new Context(
209             control,
210             position == null ? Position.ENDING : position,
211             byId,
212             title,
213             params,
214             props,
215             links,
216             parts));
217   }
218 
219   @Override
220   public Boolean visitCatalog(Catalog catalog, Context context) {
221     // not required
222     throw new UnsupportedOperationException("not needed");
223   }
224 
225   @Override
226   public Boolean visitGroup(CatalogGroup group, Context context) {
227     // not required
228     throw new UnsupportedOperationException("not needed");
229   }
230 
231   /**
232    * If the add applies to the current object, then apply the child objects.
233    * <p>
234    * An add applies if:
235    * <ol>
236    * <li>the {@code targetItem} supports all of the children</li>
237    * <li>the context matches if:
238    * <ul>
239    * <li>the target item's id matches the "by-id"; or</li>
240    * <li>the "by-id" is not defined and the target item is the control matching
241    * the target context</li>
242    * </ul>
243    * </li>
244    * </ol>
245    *
246    * @param <T>
247    *          the type of the {@code targetItem}
248    * @param targetItem
249    *          the current target to process
250    * @param titleConsumer
251    *          a consumer to apply a title to or {@code null} if the object has no
252    *          title field
253    * @param paramsSupplier
254    *          a supplier for the child {@link Parameter} collection
255    * @param propsSupplier
256    *          a supplier for the child {@link Property} collection
257    * @param linksSupplier
258    *          a supplier for the child {@link Link} collection
259    * @param partsSupplier
260    *          a supplier for the child {@link ControlPart} collection
261    * @param context
262    *          the add context
263    * @return {@code true} if a modification was made or {@code false} otherwise
264    */
265   private static <T> boolean handleCurrent(
266       @NonNull T targetItem,
267       @Nullable Consumer<MarkupLine> titleConsumer,
268       @Nullable Supplier<? extends List<Parameter>> paramsSupplier,
269       @Nullable Supplier<? extends List<Property>> propsSupplier,
270       @Nullable Supplier<? extends List<Link>> linksSupplier,
271       @Nullable Supplier<? extends List<ControlPart>> partsSupplier,
272       @NonNull Context context) {
273     boolean retval = false;
274     Position position = context.getPosition();
275     if (context.appliesTo(targetItem) && !context.isSequenceTargeted(targetItem)) {
276       // the target item is the target of the add
277       MarkupLine newTitle = context.getTitle();
278       if (newTitle != null) {
279         assert titleConsumer != null;
280         titleConsumer.accept(newTitle);
281       }
282 
283       handleCollection(position, context.getParams(), paramsSupplier);
284       handleCollection(position, context.getProps(), propsSupplier);
285       handleCollection(position, context.getLinks(), linksSupplier);
286       handleCollection(position, context.getParts(), partsSupplier);
287       retval = true;
288     }
289     return retval;
290   }
291 
292   private static <T> void handleCollection(
293       @NonNull Position position,
294       @NonNull List<T> newItems,
295       @Nullable Supplier<? extends List<T>> originalCollectionSupplier) {
296     if (originalCollectionSupplier != null) {
297       List<T> oldItems = originalCollectionSupplier.get();
298       if (!newItems.isEmpty()) {
299         if (Position.STARTING.equals(position)) {
300           oldItems.addAll(0, newItems);
301         } else { // ENDING
302           oldItems.addAll(newItems);
303         }
304       }
305     }
306   }
307 
308   // private static <T> void handleChild(
309   // @NonNull TargetType itemType,
310   // @NonNull Supplier<? extends List<T>> collectionSupplier,
311   // @Nullable Consumer<T> handler,
312   // @NonNull Context context) {
313   // boolean handleChildren = !Collections.disjoint(context.getTargetItemTypes(),
314   // getApplicableTypes(itemType));
315   // if (handleChildren && handler != null) {
316   // // if the child item type is applicable and there is a handler, iterate over
317   // children
318   // Iterator<T> iter = collectionSupplier.get().iterator();
319   // while (iter.hasNext()) {
320   // T item = iter.next();
321   // if (item != null) {
322   // handler.accept(item);
323   // }
324   // }
325   // }
326   // }
327 
328   private static <T> boolean handleChild(
329       @NonNull TargetType itemType,
330       @NonNull Supplier<? extends List<T>> originalCollectionSupplier,
331       @NonNull Supplier<? extends List<T>> newItemsSupplier,
332       @Nullable Function<T, Boolean> handler,
333       @NonNull Context context) {
334 
335     // determine if this child type can match
336     boolean isItemTypeMatch = context.isMatchingType(itemType);
337 
338     Set<TargetType> applicableTypes = getApplicableTypes(itemType);
339     boolean descendChild = handler != null && !Collections.disjoint(context.getTargetItemTypes(), applicableTypes);
340 
341     boolean retval = false;
342     if (isItemTypeMatch || descendChild) {
343       // if the item type is applicable, attempt to match by id
344       List<T> collection = originalCollectionSupplier.get();
345       ListIterator<T> iter = collection.listIterator();
346       boolean deferred = false;
347       while (iter.hasNext()) {
348         T item = ObjectUtils.requireNonNull(iter.next());
349 
350         if (isItemTypeMatch && context.appliesTo(item) && context.isSequenceTargeted(item)) {
351           // if id match, inject the new items into the collection
352           switch (context.getPosition()) {
353           case AFTER: {
354             newItemsSupplier.get().forEach(iter::add);
355             retval = true;
356             break;
357           }
358           case BEFORE: {
359             iter.previous();
360             List<T> adds = newItemsSupplier.get();
361             adds.forEach(iter::add);
362             item = iter.next();
363             retval = true;
364             break;
365           }
366           case STARTING:
367           case ENDING:
368             deferred = true;
369             break;
370           default:
371             throw new UnsupportedOperationException(context.getPosition().name().toLowerCase(Locale.ROOT));
372           }
373         }
374 
375         if (descendChild) {
376           assert handler != null;
377 
378           // handle child items since they are applicable to the search criteria
379           retval = retval || handler.apply(item);
380         }
381       }
382 
383       if (deferred) {
384         List<T> newItems = newItemsSupplier.get();
385         if (Position.ENDING.equals(context.getPosition())) {
386           collection.addAll(newItems);
387           retval = true;
388         } else if (Position.STARTING.equals(context.getPosition())) {
389           collection.addAll(0, newItems);
390           retval = true;
391         }
392       }
393     }
394     return retval;
395   }
396 
397   @Override
398   public Boolean visitControl(Control control, Context context) {
399     assert context != null;
400 
401     if (control.getParams() == null) {
402       control.setParams(new LinkedList<>());
403     }
404 
405     if (control.getProps() == null) {
406       control.setProps(new LinkedList<>());
407     }
408 
409     if (control.getLinks() == null) {
410       control.setLinks(new LinkedList<>());
411     }
412 
413     if (control.getParts() == null) {
414       control.setParts(new LinkedList<>());
415     }
416 
417     boolean retval = handleCurrent(
418         control,
419         control::setTitle,
420         control::getParams,
421         control::getProps,
422         control::getLinks,
423         control::getParts,
424         context);
425 
426     // visit params
427     retval = retval || handleChild(
428         TargetType.PARAM,
429         control::getParams,
430         context::getParams,
431         child -> visitParameter(ObjectUtils.notNull(child), context),
432         context);
433 
434     // visit parts
435     retval = retval || handleChild(
436         TargetType.PART,
437         control::getParts,
438         context::getParts,
439         child -> visitPart(child, context),
440         context);
441 
442     // visit control children
443     for (Control childControl : CollectionUtil.listOrEmpty(control.getControls())) {
444       Set<TargetType> applicableTypes = getApplicableTypes(TargetType.CONTROL);
445       if (!Collections.disjoint(context.getTargetItemTypes(), applicableTypes)) {
446         retval = retval || visitControl(ObjectUtils.requireNonNull(childControl), context);
447       }
448     }
449     return retval;
450   }
451 
452   @Override
453   public Boolean visitParameter(Parameter parameter, Context context) {
454     assert context != null;
455     if (parameter.getProps() == null) {
456       parameter.setProps(new LinkedList<>());
457     }
458 
459     if (parameter.getLinks() == null) {
460       parameter.setLinks(new LinkedList<>());
461     }
462 
463     return handleCurrent(
464         parameter,
465         null,
466         null,
467         parameter::getProps,
468         parameter::getLinks,
469         null,
470         context);
471   }
472 
473   /**
474    * Visit the control part.
475    *
476    * @param part
477    *          the bound part object
478    * @param context
479    *          the visitor context
480    * @return {@code true} if the removal was applied or {@code false} otherwise
481    */
482   public boolean visitPart(ControlPart part, Context context) {
483     assert context != null;
484     if (part.getProps() == null) {
485       part.setProps(new LinkedList<>());
486     }
487 
488     if (part.getLinks() == null) {
489       part.setLinks(new LinkedList<>());
490     }
491 
492     if (part.getParts() == null) {
493       part.setParts(new LinkedList<>());
494     }
495 
496     boolean retval = handleCurrent(
497         part,
498         null,
499         null,
500         part::getProps,
501         part::getLinks,
502         part::getParts,
503         context);
504 
505     return retval || handleChild(
506         TargetType.PART,
507         part::getParts,
508         context::getParts,
509         child -> visitPart(child, context),
510         context);
511   }
512 
513   static final class Context {
514     @NonNull
515     private static final Set<TargetType> TITLE_TYPES = ObjectUtils.notNull(
516         Set.of(TargetType.CONTROL, TargetType.PART));
517     @NonNull
518     private static final Set<TargetType> PARAM_TYPES = ObjectUtils.notNull(
519         Set.of(TargetType.CONTROL, TargetType.PARAM));
520     @NonNull
521     private static final Set<TargetType> PROP_TYPES = ObjectUtils.notNull(
522         Set.of(TargetType.CONTROL, TargetType.PARAM, TargetType.PART));
523     @NonNull
524     private static final Set<TargetType> LINK_TYPES = ObjectUtils.notNull(
525         Set.of(TargetType.CONTROL, TargetType.PARAM, TargetType.PART));
526     @NonNull
527     private static final Set<TargetType> PART_TYPES = ObjectUtils.notNull(
528         Set.of(TargetType.CONTROL, TargetType.PART));
529 
530     @NonNull
531     private final Control control;
532     @NonNull
533     private final Position position;
534     @Nullable
535     private final String byId;
536     @Nullable
537     private final MarkupLine title;
538     @NonNull
539     private final List<Parameter> params;
540     @NonNull
541     private final List<Property> props;
542     @NonNull
543     private final List<Link> links;
544     @NonNull
545     private final List<ControlPart> parts;
546     @NonNull
547     private final Set<TargetType> targetItemTypes;
548 
549     public Context(
550         @NonNull Control control,
551         @NonNull Position position,
552         @Nullable String byId,
553         @Nullable MarkupLine title,
554         @NonNull List<Parameter> params,
555         @NonNull List<Property> props,
556         @NonNull List<Link> links,
557         @NonNull List<ControlPart> parts) {
558 
559       Set<TargetType> targetItemTypes = ObjectUtils.notNull(EnumSet.allOf(TargetType.class));
560 
561       List<String> additionObjects = new LinkedList<>();
562 
563       boolean sequenceTarget = true;
564       if (title != null) {
565         targetItemTypes.retainAll(TITLE_TYPES);
566         additionObjects.add("title");
567         sequenceTarget = false;
568       }
569 
570       if (!params.isEmpty()) {
571         targetItemTypes.retainAll(PARAM_TYPES);
572         additionObjects.add("param");
573       }
574 
575       if (!props.isEmpty()) {
576         targetItemTypes.retainAll(PROP_TYPES);
577         additionObjects.add("prop");
578         sequenceTarget = false;
579       }
580 
581       if (!links.isEmpty()) {
582         targetItemTypes.retainAll(LINK_TYPES);
583         additionObjects.add("link");
584         sequenceTarget = false;
585       }
586 
587       if (!parts.isEmpty()) {
588         targetItemTypes.retainAll(PART_TYPES);
589         additionObjects.add("part");
590       }
591 
592       if (Position.BEFORE.equals(position) || Position.AFTER.equals(position)) {
593         if (!sequenceTarget) {
594           throw new ProfileResolutionEvaluationException(
595               "When using position before or after, one collection of parameters or parts can be specified."
596                   + " Other additions must not be used.");
597         }
598         if (sequenceTarget && !params.isEmpty() && parts.isEmpty()) {
599           targetItemTypes.retainAll(Set.of(TargetType.PARAM));
600         } else if (sequenceTarget && !parts.isEmpty() && params.isEmpty()) {
601           targetItemTypes.retainAll(Set.of(TargetType.PART));
602         } else {
603           throw new ProfileResolutionEvaluationException(
604               "When using position before or after, only one collection of parameters or parts can be specified.");
605         }
606       }
607 
608       if (targetItemTypes.isEmpty()) {
609         throw new ProfileResolutionEvaluationException("No parent object supports the requested objects to add: " +
610             additionObjects.stream().collect(CustomCollectors.joiningWithOxfordComma("or")));
611       }
612 
613       this.control = control;
614       this.position = position;
615       this.byId = byId;
616       this.title = title;
617       this.params = params;
618       this.props = props;
619       this.links = links;
620       this.parts = parts;
621       this.targetItemTypes = CollectionUtil.unmodifiableSet(targetItemTypes);
622     }
623 
624     public Control getControl() {
625       return control;
626     }
627 
628     @NonNull
629     public Position getPosition() {
630       return position;
631     }
632 
633     @Nullable
634     public String getById() {
635       return byId;
636     }
637 
638     @Nullable
639     public MarkupLine getTitle() {
640       return title;
641     }
642 
643     @NonNull
644     public List<Parameter> getParams() {
645       return params;
646     }
647 
648     @NonNull
649     public List<Property> getProps() {
650       return props;
651     }
652 
653     @NonNull
654     public List<Link> getLinks() {
655       return links;
656     }
657 
658     @NonNull
659     public List<ControlPart> getParts() {
660       return parts;
661     }
662 
663     @NonNull
664     public Set<TargetType> getTargetItemTypes() {
665       return targetItemTypes;
666     }
667 
668     public boolean isMatchingType(@NonNull TargetType type) {
669       return getTargetItemTypes().contains(type);
670     }
671 
672     public <T> boolean isSequenceTargeted(T targetItem) {
673       TargetType objectType = TargetType.forClass(targetItem.getClass());
674       return (Position.BEFORE.equals(position) || Position.AFTER.equals(position))
675           && (TargetType.PARAM.equals(objectType) && isMatchingType(TargetType.PARAM)
676               || TargetType.PART.equals(objectType) && isMatchingType(TargetType.PART));
677     }
678 
679     /**
680      * Determine if the provided {@code obj} is the target of the add.
681      *
682      * @param obj
683      *          the current object
684      * @return {@code true} if the current object applies or {@code false} otherwise
685      */
686     public boolean appliesTo(@NonNull Object obj) {
687       TargetType objectType = TargetType.forClass(obj.getClass());
688 
689       boolean retval = objectType != null && isMatchingType(objectType);
690       if (retval) {
691         assert objectType != null;
692 
693         // check other criteria
694         String actualId = null;
695 
696         switch (objectType) {
697         case CONTROL: {
698           Control control = (Control) obj;
699           actualId = control.getId();
700           break;
701         }
702         case PARAM: {
703           Parameter param = (Parameter) obj;
704           actualId = param.getId();
705           break;
706         }
707         case PART: {
708           ControlPart part = (ControlPart) obj;
709           actualId = part.getId() == null ? null : part.getId();
710           break;
711         }
712         default:
713           throw new UnsupportedOperationException(objectType.fieldName());
714         }
715 
716         String byId = getById();
717         if (getById() == null && TargetType.CONTROL.equals(objectType)) {
718           retval = getControl().equals(obj);
719         } else {
720           retval = byId != null && byId.equals(actualId);
721         }
722       }
723       return retval;
724     }
725   }
726 }